diff --git a/.github/.golangci.yml b/.github/.golangci.yml index f3c0a5f..c3a684d 100644 --- a/.github/.golangci.yml +++ b/.github/.golangci.yml @@ -5,56 +5,67 @@ linters: - asciicheck - bidichk - bodyclose + - canonicalheader - containedctx - contextcheck + - copyloopvar - cyclop - decorder #- depguard - dogsled - dupl + - dupword - durationcheck + - err113 - errcheck - errchkjson - errname - errorlint - - execinquery - - exhaustive + #- exhaustive - exhaustruct - exportloopref + - fatcontext - forbidigo - forcetypeassert - funlen - gci + - ginkgolinter + - gocheckcompilerdirectives #- gochecknoglobals - #- gochecknoinits + - gochecknoinits + - gochecksumtype - gocognit - goconst - gocritic - gocyclo - godot - godox - - goerr113 - gofmt - gofumpt - goheader - goimports - #- gomnd - gomoddirectives - gomodguard - goprintffuncname - gosec - gosimple + - gosmopolitan - govet - grouper - importas + - inamedparam - ineffassign - interfacebloat + - intrange - ireturn - lll - - logrlint + - loggercheck - maintidx - makezero + - mirror - misspell + #- mnd + - musttag - nakedret - nestif - nilerr @@ -62,21 +73,27 @@ linters: - nlreturn - noctx - nolintlint - #- nonamedreturns + - nonamedreturns - nosprintfhostport - paralleltest + - perfsprint - prealloc - predeclared - promlinter + - protogetter - reassign - revive - rowserrcheck + - sloglint + - spancheck - sqlclosecheck - staticcheck - stylecheck + - tagalign - tagliatelle - tenv - #- testableexamples + - testableexamples + - testifylint - testpackage - thelper - tparallel @@ -88,19 +105,11 @@ linters: #- varnamelen - wastedassign - whitespace - - wrapcheck + #- wrapcheck - wsl - presets: - - bugs - - unused - fast: false + - zerologlint linters-settings: - cyclop: - # The maximal code complexity to report. - # Default: 10 - max-complexity: 12 - skip-tests: true dupl: threshold: 100 errcheck: @@ -135,6 +144,8 @@ linters-settings: - opinionated - performance - style + disabled-checks: + - unnamedResult gocyclo: min-complexity: 15 godox: @@ -148,23 +159,12 @@ linters-settings: simplify: true goimports: local-prefixes: github.com/bytemare/voprf - golint: - min-confidence: 0 - gomnd: - checks: - - argument - - case - - condition - - operation - - return - - assign - ignored-functions: - - 'Import' - - 'Deserialize' - - 'DeriveKeyPair' - - 'i2osp2' + gosimple: + checks: [ "all" ] govet: - check-shadowing: true + settings: + shadow: + strict: true disable-all: true enable: - asmdecl @@ -209,13 +209,23 @@ linters-settings: tab-width: 4 misspell: locale: US + mnd: + checks: + - argument + - case + - condition + - operation + - return + - assign + #ignored-functions: + # - 'nist.setMapping' + # - 'big.NewInt' + # - 'hash2curve.HashToFieldXMD' nlreturn: block-size: 2 prealloc: simple: false for-loops: true - unused: - check-exported: false whitespace: multi-if: false multi-func: false @@ -236,17 +246,18 @@ issues: # But independently from this option we use default exclude patterns, # it can be disabled by `exclude-use-default: false`. To list all # excluded by default patterns execute `golangci-lint run --help` - exclude: - - "should have a package comment, unless it's in another file for this package" + #exclude: + #- "should have a package comment, unless it's in another file for this package" + #- "do not define dynamic errors, use wrapped static errors instead" + #- "missing cases in switch of type Group: maxID" - exclude-rules: - # Exclude some linters from running on tests files. - - path: _test\.go - linters: - - gocyclo - - errcheck - - dupl - - gosec + #exclude-rules: + # - path: internal/hash.go + # linters: + # - errcheck + # - path: internal/tag/strings.go + # linters: + # - gosec max-issues-per-linter: 0 max-same-issues: 0 @@ -260,5 +271,7 @@ issues: run: tests: false -output: - format: github-actions \ No newline at end of file +#output: +# formats: +# - format: github-actions +# show-stats: true \ No newline at end of file diff --git a/README.md b/README.md index 2a2105d..b6cd164 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,8 @@ [![Go Reference](https://pkg.go.dev/badge/github.com/bytemare/voprf.svg)](https://pkg.go.dev/github.com/bytemare/voprf) [![codecov](https://codecov.io/gh/bytemare/voprf/branch/main/graph/badge.svg?token=5bQfB0OctA)](https://codecov.io/gh/bytemare/voprf) -Package voprf provides abstracted access to Oblivious Pseudorandom Functions (OPRF) over Elliptic Curves as specified in -[RFC9497](https://datatracker.ietf.org/doc/rfc9497) and fully supports the OPRF, VOPRF, and POPRF protocols. +Package voprf implements [RFC9497](https://datatracker.ietf.org/doc/rfc9497) and provides Oblivious Pseudorandom Functions +(OPRF) over Elliptic Curves, and fully supports the OPRF, VOPRF, and POPRF protocols. ## Versioning diff --git a/client.go b/client.go deleted file mode 100644 index 8963d45..0000000 --- a/client.go +++ /dev/null @@ -1,228 +0,0 @@ -// SPDX-License-Identifier: MIT -// -// Copyright (C) 2024 Daniel Bourdrez. All Rights Reserved. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree or at -// https://spdx.org/licenses/MIT.html - -package voprf - -import ( - "errors" - "fmt" - - group "github.com/bytemare/crypto" -) - -var ( - errArrayLength = errors.New("blinding init failed, non-nil array of incompatible length") - errNilProofC = errors.New("c proof is nil or empty") - errNilProofS = errors.New("s proof is nil or empty") - errInvalidNumElements = errors.New("invalid number of element ") - errInvalidInput = errors.New( - "invalid input - OPRF input deterministically maps to the group identity element", - ) -) - -// Client represents the Client/Verifier party in a (V)OPRF protocol session, -// and exposes relevant functions for its execution. -type Client struct { - tweakedKey *group.Element - serverPublicKey *group.Element - *oprf - - input [][]byte - blind []*group.Scalar - blindedElement []*group.Element -} - -// SetBlinds sets the inner blinds to those given as input. -func (c *Client) SetBlinds(blind []*group.Scalar) { - c.blind = blind -} - -// Blind blinds, or masks, the input with a preset or new random blinding element. -func (c *Client) Blind(input, info []byte) []byte { - if err := c.initBlinding(1); err != nil { - panic(err) - } - - c.innerBlind(input, info, 0) - - return c.blindedElement[0].Encode() -} - -// BlindBatch allows blinding of batched input. If internal blinds are not set, new ones are created. In either case, -// the blinds are returned, and can safely be ignored if not needed externally. Subsequent calls on unblinding functions -// will automatically use the internal blinds, unless specified otherwise through unblindBatchWithBlinds(). -func (c *Client) BlindBatch(input [][]byte, info []byte) (blinds, blindedElements [][]byte, err error) { - if err := c.initBlinding(len(input)); err != nil { - return nil, nil, err - } - - blinds = make([][]byte, len(input)) - blindedElements = make([][]byte, len(input)) - - for i, in := range input { - c.innerBlind(in, info, i) - // Only keep the blinds in a multiplicative mode - if c.blind[i] != nil { - blinds[i] = c.blind[i].Encode() - } - - blindedElements[i] = c.blindedElement[i].Encode() - } - - return blinds, blindedElements, nil -} - -// BlindBatchWithBlinds enables blinding batches while specifying which blinds to use. -func (c *Client) BlindBatchWithBlinds(blinds, input [][]byte, info []byte) ([][]byte, error) { - if len(blinds) != len(input) { - return nil, errParamInputEqualLen - } - - if err := c.initBlinding(len(blinds)); err != nil { - return nil, err - } - - blindedElements := make([][]byte, len(input)) - - for i, blind := range blinds { - s := c.group.NewScalar() - if err := s.Decode(blind); err != nil { - return nil, fmt.Errorf("input blind %d decoding errored with %w", i, err) - } - - c.input[i] = input[i] - c.blind[i] = s - c.innerBlind(input[i], info, i) - blindedElements[i] = c.blindedElement[i].Encode() - } - - return blindedElements, nil -} - -// Finalize finalizes the protocol execution by verifying the proof if necessary, -// unblinding the evaluated element, and hashing the transcript. -func (c *Client) Finalize(e *Evaluation, info []byte) ([]byte, error) { - output, err := c.FinalizeBatch(e, info) - if err != nil { - return nil, err - } - - return output[0], nil -} - -// FinalizeBatch finalizes the protocol execution by verifying the proof if necessary, -// unblinding the evaluated elements, and hashing the transcript. -func (c *Client) FinalizeBatch(e *Evaluation, info []byte) ([][]byte, error) { - if len(e.Elements) != len(c.input) { - return nil, errParamFinalizeLen - } - - ev, err := e.deserialize(c.group) - if err != nil { - return nil, err - } - - if len(ev.elements) != len(c.blindedElement) { - return nil, errInvalidNumElements - } - - if c.oprf.mode == OPRF || c.oprf.mode == VOPRF { - info = nil - } - - if c.oprf.mode == VOPRF || c.oprf.mode == POPRF { - if err := c.verifyProof(ev); err != nil { - return nil, err - } - } - - out := make([][]byte, len(c.input)) - - for i, ee := range ev.elements { - u := c.unblind(ee, c.blind[i]) - out[i] = c.hashTranscript(c.input[i], info, u.Encode()) - } - - return out, nil -} - -func (c *Client) initBlinding(length int) error { - if len(c.input) == 0 { - c.input = make([][]byte, length) - } else if len(c.input) != length { - return errArrayLength - } - - if len(c.blind) == 0 { - c.blind = make([]*group.Scalar, length) - } else if len(c.blind) != length { - return errArrayLength - } - - if len(c.blindedElement) == 0 { - c.blindedElement = make([]*group.Element, length) - } else if len(c.blindedElement) != length { - return errArrayLength - } - - return nil -} - -func (c *Client) verifyProof(ev *evaluation) error { - if ev.proofC == nil { - return errNilProofC - } - - if ev.proofS == nil { - return errNilProofS - } - - var pk *group.Element - var cs, ds []*group.Element - - if c.oprf.mode == VOPRF { - cs, ds = c.blindedElement, ev.elements - pk = c.serverPublicKey - } else { // POPRF - cs, ds = ev.elements, c.blindedElement - pk = c.tweakedKey - } - - return c.oprf.verifyProof(ev, pk, cs, ds) -} - -func (c *Client) innerBlind(input, info []byte, index int) { - if c.blind[index] == nil { - c.blind[index] = c.group.NewScalar().Random() - } - - c.input[index] = input - - if c.oprf.mode == POPRF { - m := c.pTag(info) - - t := c.group.Base().Multiply(m).Add(c.serverPublicKey) - if t.IsIdentity() { - panic(errInvalidInput) - } - - c.tweakedKey = t - } - - p := c.HashToGroup(input) - if p.IsIdentity() { - panic(errInvalidInput) - } - - c.blindedElement[index] = p.Multiply(c.blind[index]) -} - -func (c *Client) unblind(evaluated *group.Element, blind *group.Scalar) *group.Element { - inv := blind.Copy().Invert() - return evaluated.Multiply(inv) -} diff --git a/client_state.go b/client_state.go deleted file mode 100644 index 4e46ae3..0000000 --- a/client_state.go +++ /dev/null @@ -1,176 +0,0 @@ -// SPDX-License-Identifier: MIT -// -// Copyright (C) 2024 Daniel Bourdrez. All Rights Reserved. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree or at -// https://spdx.org/licenses/MIT.html - -package voprf - -import ( - "fmt" - - group "github.com/bytemare/crypto" -) - -// State represents a client's state, allowing internal values to be exported and imported to resume a session. -type State struct { - Identifier Ciphersuite `json:"s"` - TweakedKey []byte `json:"t,omitempty"` - ServerPublicKey []byte `json:"p,omitempty"` - Input [][]byte `json:"i"` - Blind [][]byte `json:"r"` - Blinded [][]byte `json:"d"` - Mode Mode `json:"m"` -} - -// Export extracts the client's internal values that can be imported in another client for session resumption. -func (c *Client) Export() *State { - s := &State{ - Identifier: c.ciphersuite, - TweakedKey: nil, - ServerPublicKey: nil, - Input: nil, - Blind: nil, - Blinded: nil, - Mode: c.mode, - } - - if c.serverPublicKey != nil { - s.ServerPublicKey = c.serverPublicKey.Encode() - } - - if c.tweakedKey != nil { - s.TweakedKey = c.tweakedKey.Encode() - } - - if len(c.input) != len(c.blind) { - panic("different number of input and blind values") - } - - s.Input = make([][]byte, len(c.input)) - s.Blind = make([][]byte, len(c.blind)) - s.Blinded = make([][]byte, len(c.blindedElement)) - - for i := 0; i < len(c.input); i++ { - s.Input[i] = make([]byte, len(c.input[i])) - copy(s.Input[i], c.input[i]) - s.Blind[i] = c.blind[i].Encode() - s.Blinded[i] = c.blindedElement[i].Encode() - } - - return s -} - -// RecoverClient returns a Client recovered form the state, from which a session can be resumed. -func (s *State) RecoverClient() (*Client, error) { - if s.Mode != OPRF && s.Mode != VOPRF && s.Mode != POPRF { - return nil, errParamInvalidMode - } - - if !s.Identifier.Available() { - return nil, errParamInvalidID - } - - c := s.Identifier.client(s.Mode) - - if err := importPrecheck(s); err != nil { - return nil, err - } - - c.oprf = s.Identifier.new(s.Mode) - - if err := c.importTweakedKey(s); err != nil { - return nil, err - } - - if err := c.importPublicKey(s); err != nil { - return nil, err - } - - if err := c.importBlinds(s); err != nil { - return nil, err - } - - if err := c.importBlinded(s); err != nil { - return nil, err - } - - return c, nil -} - -func importPrecheck(state *State) error { - if len(state.Input) != len(state.Blinded) { - return errStateDiffInput - } - - if len(state.Blinded) != 0 && len(state.Blinded) != len(state.Blind) { - return errStateDiffBlind - } - - if state.Mode == VOPRF && state.ServerPublicKey == nil { - return errStateNoPubKey - } - - return nil -} - -func (c *Client) importTweakedKey(state *State) error { - if state.TweakedKey != nil { - t := c.group.NewElement() - if err := t.Decode(state.TweakedKey); err != nil { - return fmt.Errorf("tweaked key - %w", err) - } - - c.tweakedKey = t - } - - return nil -} - -func (c *Client) importPublicKey(state *State) error { - if state.ServerPublicKey != nil { - pk := c.group.NewElement() - if err := pk.Decode(state.ServerPublicKey); err != nil { - return fmt.Errorf("server public key - %w", err) - } - - c.serverPublicKey = pk - } - - return nil -} - -func (c *Client) importBlinds(state *State) error { - c.blind = make([]*group.Scalar, len(state.Blind)) - for i := 0; i < len(state.Blind); i++ { - blind := c.group.NewScalar() - if err := blind.Decode(state.Blind[i]); err != nil { - return fmt.Errorf("blind %d - %w", i, err) - } - - c.blind[i] = blind - } - - return nil -} - -func (c *Client) importBlinded(state *State) error { - c.input = make([][]byte, len(state.Input)) - c.blindedElement = make([]*group.Element, len(state.Blinded)) - - for i := 0; i < len(state.Blinded); i++ { - c.input[i] = make([]byte, len(state.Input[i])) - copy(c.input[i], state.Input[i]) - - blinded := c.group.NewElement() - if err := blinded.Decode(state.Blinded[i]); err != nil { - return fmt.Errorf("invalid blinded element: %w", err) - } - - c.blindedElement[i] = blinded - } - - return nil -} diff --git a/doc.go b/doc.go deleted file mode 100644 index d0fca86..0000000 --- a/doc.go +++ /dev/null @@ -1,13 +0,0 @@ -// SPDX-License-Identifier: MIT -// -// Copyright (C) 2024 Daniel Bourdrez. All Rights Reserved. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree or at -// https://spdx.org/licenses/MIT.html - -// Package voprf provides abstracted access to Oblivious Pseudorandom Functions (OPRF) -// and VOPRF Oblivious Pseudorandom Functions (VOPRF) using Elliptic Curves (EC(V)OPRF). -// -// This implements RFC9497. -package voprf diff --git a/encoding.go b/encoding.go new file mode 100644 index 0000000..296d7a1 --- /dev/null +++ b/encoding.go @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (C) 2024 Daniel Bourdrez. All Rights Reserved. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree or at +// https://spdx.org/licenses/MIT.html + +// Package voprf implements RFC9497 and provides abstracted access to Oblivious Pseudorandom Functions (OPRF) and +// Threshold Oblivious Pseudorandom Functions (TOPRF) using Elliptic Curve Prime Order Groups (EC-OPRF). +// For VOPRF and POPRF use the github.com/bytemare/oprf/voprf package. +package voprf + +import group "github.com/bytemare/crypto" + +// DecodeElement decodes e to an element in the group. +func (c Ciphersuite) DecodeElement(e []byte) (*group.Element, error) { + result := group.Group(c).NewElement() + + if err := result.Decode(e); err != nil { + return nil, err + } + + return result, nil +} + +// DecodeScalar decodes s to a scalar in the group. +func (c Ciphersuite) DecodeScalar(s []byte) (*group.Scalar, error) { + result := group.Group(c).NewScalar() + + if err := result.Decode(s); err != nil { + return nil, err + } + + return result, nil +} diff --git a/errors.go b/errors.go deleted file mode 100644 index ad506ba..0000000 --- a/errors.go +++ /dev/null @@ -1,30 +0,0 @@ -// SPDX-License-Identifier: MIT -// -// Copyright (C) 2024 Daniel Bourdrez. All Rights Reserved. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree or at -// https://spdx.org/licenses/MIT.html - -package voprf - -import "errors" - -var ( - errParamInvalidMode = errors.New("invalid OPRF mode") - errParamInvalidID = errors.New("invalid Ciphersuite") - errParamFinalizeLen = errors.New("invalid number of elements in evaluation") - errParamInputEqualLen = errors.New("input lengths are not equal") - errParamNoPubKey = errors.New("missing public key") - - errEvalSerDeMin = errors.New("evaluation : insufficient header length") - errEvalSerDeElements = errors.New("evaluation : insufficient number of evaluations") - errEvalSerDeProofLen = errors.New("evaluation : invalid length of proof") - - errStateDiffInput = errors.New("state : different number of input and blinded values") - errStateDiffBlind = errors.New("state : got blinded elements but different number of blinds") - errStateNoPubKey = errors.New("state in verifiable mode but no server public key") - - errProofFailed = errors.New("proof fails") - errZeroScalar = errors.New("inversion led to zero scalar") -) diff --git a/evaluation.go b/evaluation.go deleted file mode 100644 index 9efe734..0000000 --- a/evaluation.go +++ /dev/null @@ -1,150 +0,0 @@ -// SPDX-License-Identifier: MIT -// -// Copyright (C) 2024 Daniel Bourdrez. All Rights Reserved. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree or at -// https://spdx.org/licenses/MIT.html - -package voprf - -import ( - "fmt" - - group "github.com/bytemare/crypto" -) - -// Evaluation holds the serialized evaluated elements and serialized proof. -type Evaluation struct { - // Elements represents the unique serialization of Elements - Elements [][]byte `json:"e"` - - // Proofs - ProofC []byte `json:"c,omitempty"` - ProofS []byte `json:"s,omitempty"` -} - -// Serialize returns a compact encoding of the Evaluation. -func (e *Evaluation) Serialize() []byte { - ne := len(e.Elements) - lp := len(e.Elements[0]) - s := make([]byte, 0, 2+2+ne*lp) - s = append(s, i2osp2(ne)...) - s = append(s, i2osp2(lp)...) - - for _, el := range e.Elements { - s = append(s, el...) - } - - if e.ProofC != nil && e.ProofS != nil { - s = append(s, e.ProofC...) - s = append(s, e.ProofS...) - } - - return s -} - -// Deserialize decodes the input into the Evaluation. -func (e *Evaluation) Deserialize(input []byte) error { - length := len(input) - if length < 4 { - return errEvalSerDeMin - } - - ne := int(uint16(input[1]) | uint16(input[0])<<8) - lp := int(uint16(input[3]) | uint16(input[2])<<8) - - if length < 4+ne*lp { - return errEvalSerDeElements - } - - e.Elements = make([][]byte, ne) - - offset := 4 - - for i := 0; i < ne; i++ { - e.Elements[i] = make([]byte, lp) - copy(e.Elements[i], input[offset:offset+lp]) - offset += lp - } - - // if there's more than elements, there might be proof - if offset < length { - proof := input[offset:] - if len(proof)&1 == 1 { - return errEvalSerDeProofLen - } - - offset = len(proof) / 2 - e.ProofC = make([]byte, offset) - copy(e.ProofC, proof[:offset]) - e.ProofS = make([]byte, offset) - copy(e.ProofS, proof[offset:]) - } - - return nil -} - -// deserialize returns a structure with the internal representations of the evaluated elements and proofs. -func (e *Evaluation) deserialize(g group.Group) (*evaluation, error) { - eval := &evaluation{ - proofC: nil, - proofS: nil, - elements: make([]*group.Element, len(e.Elements)), - } - - for i, el := range e.Elements { - elm := g.NewElement() - if err := elm.Decode(el); err != nil { - return nil, fmt.Errorf("could not decode element : %w", err) - } - - eval.elements[i] = elm - } - - if len(e.ProofC) != 0 { - eval.proofC = g.NewScalar() - if err := eval.proofC.Decode(e.ProofC); err != nil { - return nil, fmt.Errorf("invalid c scalar proof: %w", err) - } - } - - if len(e.ProofS) != 0 { - eval.proofS = g.NewScalar() - if err := eval.proofS.Decode(e.ProofS); err != nil { - return nil, fmt.Errorf("invalid c scalar proof: %w", err) - } - } - - return eval, nil -} - -// evaluation holds the evaluated elements and proofs in their internal representations. -type evaluation struct { - proofC *group.Scalar - proofS *group.Scalar - elements []*group.Element -} - -// serialize the components of the evaluation into byte arrays to be exposed in API. -func (e *evaluation) serialize() *Evaluation { - ev := &Evaluation{ - Elements: make([][]byte, len(e.elements)), - ProofC: nil, - ProofS: nil, - } - - for i, el := range e.elements { - ev.Elements[i] = el.Encode() - } - - if e.proofC != nil { - ev.ProofC = e.proofC.Encode() - } - - if e.proofS != nil { - ev.ProofS = e.proofS.Encode() - } - - return ev -} diff --git a/examples_test.go b/examples_test.go index dd7990e..39e9827 100644 --- a/examples_test.go +++ b/examples_test.go @@ -12,141 +12,201 @@ import ( "encoding/hex" "fmt" - "github.com/bytemare/voprf" -) + group "github.com/bytemare/crypto" -func exchangeWithServer(blinded []byte, verifiable bool) []byte { - var server *voprf.Server - var err error - privateKey, _ := hex.DecodeString("8132542d5ed08594e7522b5eac6bee38bab5868996c25a3fd2a7739be1856b04") + oprf "github.com/bytemare/voprf" + "github.com/bytemare/voprf/voprf" +) - if verifiable { - server, err = voprf.Ristretto255Sha512.Server(voprf.VOPRF, privateKey) - if err != nil { - panic(err) - } - } else { - server, err = voprf.Ristretto255Sha512.Server(voprf.OPRF, privateKey) - if err != nil { - panic(err) - } +func exchangeWithOPRFServer(blinded *group.Element) []byte { + // Let's say this is the private key the server uses. + encodedPrivateKey, _ := hex.DecodeString("8132542d5ed08594e7522b5eac6bee38bab5868996c25a3fd2a7739be1856b04") + privateKey, err := oprf.Ristretto255Sha512.DecodeScalar(encodedPrivateKey) + if err != nil { + panic(err) } - evaluation, err := server.Evaluate(blinded, nil) + return oprf.Evaluate(privateKey, blinded).Encode() +} + +func exchangeWithVOPRFServer(ciphersuite oprf.Ciphersuite, blinded *group.Element) []byte { + // Let's say this is the private key the server uses. + encodedPrivateKey, _ := hex.DecodeString("8132542d5ed08594e7522b5eac6bee38bab5868996c25a3fd2a7739be1856b04") + privateKey, err := oprf.Ristretto255Sha512.DecodeScalar(encodedPrivateKey) if err != nil { panic(err) } - ev := evaluation.Serialize() + publicKey := ciphersuite.Group().Base().Multiply(privateKey) - return ev + server := voprf.NewServer(oprf.Ristretto255Sha512) + if err = server.SetKeyPair(privateKey, publicKey); err != nil { + if err != nil { + panic(err) + } + } + + return server.Evaluate(blinded).Serialize() } // This shows you how to set up and run the base OPRF client. -func Example_client() { +func Example_oprf_client() { + // Your configuration. + ciphersuite := oprf.Ristretto255Sha512 input := []byte("input") - // Set up a new client. Not indicating a server public key indicates we don't use the verifiable mode. - client, err := voprf.Ristretto255Sha512.Client(voprf.OPRF, nil) + // Set up a new client. + client := ciphersuite.Client() + + // The following is optional and only useful in very rare edge-cases (e.g. tests), where you want to use a specific + // blind. Note that blinds are supposed to be secret and ephemeral. + // In normal circumstances, you don't need to set your blinds. + encodedBlind, _ := hex.DecodeString("39b5cfe207bfa50cf4ae02becc06332ae44f746514139896faef99b64cd7d20c") + + blind, err := ciphersuite.DecodeScalar(encodedBlind) if err != nil { panic(err) } - // The client blinds the initial input, and sends this to the server. - blinded := client.Blind(input, nil) - fmt.Printf("Send these %d bytes to the server.\n", len(blinded)) - - // Exchange with the server is not covered in this example. Let's say the server sends the following serialized - // evaluation. - evaluation, _ := hex.DecodeString("00010020b4d261d982c6edd2fea53e8a39c1df6393f23cb9d1b4768891ec2f43b8d8e831") + client.SetBlind(blind) - // The client needs to decode the evaluation to finalize the process. - eval := new(voprf.Evaluation) - if err = eval.Deserialize(evaluation); err != nil { + // The client blinds the initial input, and sends this to the server. + blinded := client.Blind(input) + fmt.Printf( + "Send these %d encoded bytes to the server: %s\n", + len(blinded.Encode()), + hex.EncodeToString(blinded.Encode()), + ) + + // For the purpose of this example, the following simulates and exchange with the server: the client sends the + // blinded element, and the server sends back the evaluated element. + encodedEvaluation := exchangeWithOPRFServer(blinded) + + // If a byte array was received, client needs to decode the encoded evaluation to finalize the process. + evaluated, err := ciphersuite.DecodeElement(encodedEvaluation) + if err != nil { panic(err) } // The client finalizes the protocol execution by reverting the blinding and hashing the protocol transcript. - output, err := client.Finalize(eval, nil) + output := client.Finalize(evaluated) if output == nil || err != nil { panic(err) } - // Output:Send these 32 bytes to the server. + + fmt.Printf("OPRF client output: %s\n", hex.EncodeToString(output)) + // Output:Send these 32 encoded bytes to the server: aee258da2b3f9c5616f19fb84f40f04278539253a02789490c0a5b380dc8eb39 + // OPRF client output: 08b80bebe6c6aa40143f46c0892f930b98efa122f89a16e62471a05c905e9ffa9be7f8e5633bb95edd28e96b113d1d0fee66b4e6a83942685a36876a6e37550b +} + +// This shows you how to set up and run the base OPRF server evaluation. +func Example_oprf_server() { + // Your configuration. + ciphersuite := oprf.Ristretto255Sha512 + encodedPrivateKey, _ := hex.DecodeString("8132542d5ed08594e7522b5eac6bee38bab5868996c25a3fd2a7739be1856b04") + + // Let's decode the server keys. + privateKey, err := oprf.Ristretto255Sha512.DecodeScalar(encodedPrivateKey) + if err != nil { + panic(err) + } + + // We suppose the client sends this blinded element. + encodedBlindedElement, _ := hex.DecodeString("7eaf3d7cbe43d54637274342ce53578b2aba836f297f4f07997a6e1dced1c058") + + // We need to decode the client provided element. + blinded, err := ciphersuite.DecodeElement(encodedBlindedElement) + if err != nil { + panic(err) + } + + // No need to set up a server, as the operation is very simple. + evaluation := oprf.Evaluate(privateKey, blinded) + + // The server encodes the evaluation, and sends it to the client. + encodedEvaluation := evaluation.Encode() + fmt.Printf("Encoded evaluation: %s", hex.EncodeToString(encodedEvaluation)) + // Output:Encoded evaluation: 8c2466a064a1eab64b226aa5a19df2115383693fe4ef260976e18949d28e9050 } // This shows you how to set up and run the Verifiable OPRF client. -func Example_verifiableClient() { - ciphersuite := voprf.Ristretto255Sha512 +func Example_voprf_client() { + // Your configuration. + ciphersuite := oprf.Ristretto255Sha512 input := []byte("input") - serverPubKey, _ := hex.DecodeString("066c39841db2ca3c2e83e251e71b619013674149692ca2ab41d1b33a1a4fff38") + serverPublicKeyHex := "066c39841db2ca3c2e83e251e71b619013674149692ca2ab41d1b33a1a4fff38" + + // To initiate the client we the server's public key. + encodedServerPubKey, _ := hex.DecodeString(serverPublicKeyHex) + serverPublicKey, _ := ciphersuite.DecodeElement(encodedServerPubKey) // Instantiate a new client with the preprocessed values. - client, err := ciphersuite.Client(voprf.VOPRF, serverPubKey) + client, err := voprf.NewClient(ciphersuite, serverPublicKey) if err != nil { panic(err) } // The client blinds the initial input, and sends this to the server. - blinded := client.Blind(input, nil) + blinded := client.Blind(input) // Exchange with the server is not covered here. The following call is to mock an exchange with a server. - evaluation := exchangeWithServer(blinded, true) + evaluation := exchangeWithVOPRFServer(ciphersuite, blinded) // The client needs to decode the evaluation to finalize the process. eval := new(voprf.Evaluation) - if err := eval.Deserialize(evaluation); err != nil { + eval.SetCiphersuite(ciphersuite) + if err = eval.Deserialize(evaluation); err != nil { panic(err) } // The client finalizes the protocol execution by reverting the blinding and hashing the protocol transcript. // If proof verification fails, an error is returned. - output, err := client.Finalize(eval, nil) + output, err := client.Finalize(eval) if output == nil || err != nil { panic(err) } // Output: } -// This shows you how to set up and run the base OPRF server. -func Example_server() { - // We suppose the client sends this blinded element. - blinded, _ := hex.DecodeString("7eaf3d7cbe43d54637274342ce53578b2aba836f297f4f07997a6e1dced1c058") - - // Set up a new server. A private key is automatically created if none is given. - server, err := voprf.Ristretto255Sha512.Server(voprf.OPRF, nil) +// This shows you how to set up and run the Verifiable OPRF server. +func Example_voprf_server() { + // Your configuration. + ciphersuite := oprf.Ristretto255Sha512 + encodedPrivateKey, _ := hex.DecodeString("8132542d5ed08594e7522b5eac6bee38bab5868996c25a3fd2a7739be1856b04") + encodedPublicKey, _ := hex.DecodeString("066c39841db2ca3c2e83e251e71b619013674149692ca2ab41d1b33a1a4fff38") + + // Let's decode the server keys. + privateKey, err := oprf.Ristretto255Sha512.DecodeScalar(encodedPrivateKey) if err != nil { panic(err) } - // The server evaluates the blinded input. - evaluation, err := server.Evaluate(blinded, nil) + publicKey, err := oprf.Ristretto255Sha512.DecodeElement(encodedPublicKey) if err != nil { panic(err) } - // The server encodes the evaluation, and sends it to the client. - _ = evaluation.Serialize() - // Output: -} + // Set up a new server. If no info is provided, the VOPRF is used. If you want to use the POPRF mode, + // you must provide the POPRF info here as the additional argument. + server := voprf.NewServer(ciphersuite) -// This shows you how to set up and run the Verifiable OPRF server. -func Example_verifiableServer() { - privateKey, _ := hex.DecodeString("8132542d5ed08594e7522b5eac6bee38bab5868996c25a3fd2a7739be1856b04") + if err = server.SetKeyPair(privateKey, publicKey); err != nil { + if err != nil { + panic(err) + } + } - // We suppose the client sends this blinded element. - blinded, _ := hex.DecodeString("7eaf3d7cbe43d54637274342ce53578b2aba836f297f4f07997a6e1dced1c058") + // Let's suppose the client sends this blinded element. + encodedBlindedElement, _ := hex.DecodeString("7eaf3d7cbe43d54637274342ce53578b2aba836f297f4f07997a6e1dced1c058") - // Set up a new server. - server, err := voprf.Ristretto255Sha512.Server(voprf.VOPRF, privateKey) + // We need to decode the client provided element. + blinded, err := ciphersuite.DecodeElement(encodedBlindedElement) if err != nil { panic(err) } - // The server evaluates the blinded input. Proofs are embedded in the evaluation. - evaluation, err := server.Evaluate(blinded, nil) - if err != nil { - panic(err) - } + // The server evaluates the blinded input. + evaluation := server.Evaluate(blinded) // The server encodes the evaluation, and sends it to the client. _ = evaluation.Serialize() diff --git a/go.mod b/go.mod index 24e6043..c44238a 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.22.2 require ( github.com/bytemare/crypto v0.6.0 github.com/bytemare/hash v0.3.0 + github.com/bytemare/secret-sharing v0.1.1-0.20240521200815-77e8c3ad0689 ) require ( @@ -13,6 +14,6 @@ require ( github.com/bytemare/hash2curve v0.3.0 // indirect github.com/bytemare/secp256k1 v0.1.2 // indirect github.com/gtank/ristretto255 v0.1.2 // indirect - golang.org/x/crypto v0.22.0 // indirect - golang.org/x/sys v0.19.0 // indirect + golang.org/x/crypto v0.23.0 // indirect + golang.org/x/sys v0.20.0 // indirect ) diff --git a/go.sum b/go.sum index 110b1a6..3d802d6 100644 --- a/go.sum +++ b/go.sum @@ -10,9 +10,11 @@ github.com/bytemare/hash2curve v0.3.0 h1:41Npcbc+u/E252A5aCMtxDcz7JPkkX1QzShneTF github.com/bytemare/hash2curve v0.3.0/go.mod h1:itj45U8uqvCtWC0eCswIHVHswXcEHkpFui7gfJdPSfQ= github.com/bytemare/secp256k1 v0.1.2 h1:aM+p/+0y1h0SZWqS/yzjGPzffVFubJvwLjUgodFEWOo= github.com/bytemare/secp256k1 v0.1.2/go.mod h1:Pxb9miDs8PTt5mOktvvXiRflvLxI1wdxbXrc6IYsaho= +github.com/bytemare/secret-sharing v0.1.1-0.20240521200815-77e8c3ad0689 h1:4KOIuk4w138AgaLtuad8PihkZGJSwnef+hbU/19RVkM= +github.com/bytemare/secret-sharing v0.1.1-0.20240521200815-77e8c3ad0689/go.mod h1:P8YRt2irx5PdiL7EwJDBkVfS3EgLcbjMFV5gqPFf3GY= github.com/gtank/ristretto255 v0.1.2 h1:JEqUCPA1NvLq5DwYtuzigd7ss8fwbYay9fi4/5uMzcc= github.com/gtank/ristretto255 v0.1.2/go.mod h1:Ph5OpO6c7xKUGROZfWVLiJf9icMDwUeIvY4OmlYW69o= -golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= -golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= -golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= -golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/internal/client.go b/internal/client.go new file mode 100644 index 0000000..88b851d --- /dev/null +++ b/internal/client.go @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (C) 2024 Daniel Bourdrez. All Rights Reserved. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree or at +// https://spdx.org/licenses/MIT.html + +package internal + +import ( + "errors" + "slices" + + group "github.com/bytemare/crypto" +) + +var ( + errBatchNoElements = errors.New("no evaluated elements provided to Finalize()") + errBatchDifferentSize = errors.New("number of evaluations is different thant number of previously blinded inputs") +) + +// A Client holds the core functionalities for all OPRF, TOPRF, VOPRF, and POPRF. +type Client struct { + // Core abstracts configuration dependent operations. + *Core + + // Inputs registry: the inputs are necessary in blinding and finalizing. + Inputs [][]byte + + // Blinds registry: the blinds are necessary in blinding and finalizing. + Blinds []*group.Scalar +} + +// NewClient loads the configuration for a new client. The info argument should only be set by the caller in the POPRF +// mode. +func NewClient(mode Mode, g group.Group) *Client { + return &Client{ + Core: LoadConfiguration(g, mode), + Inputs: make([][]byte, 1), + Blinds: make([]*group.Scalar, 1), + } +} + +// Size returns the length of the input and blind registers in its current state. +func (c *Client) Size() int { + return len(c.Inputs) +} + +// UpdateStateCapacity increases the internal input and blind registers to n, if necessary. If n is smaller than the +// current capacity, the buffers are unchanged. +func (c *Client) UpdateStateCapacity(n int) { + if n <= cap(c.Inputs) { + return + } + + d := n - cap(c.Inputs) + + c.Inputs = slices.Grow(c.Inputs, d) + c.Inputs = append(c.Inputs, make([][]byte, d)...) + c.Blinds = slices.Grow(c.Blinds, d) + c.Blinds = append(c.Blinds, make([]*group.Scalar, d)...) +} + +// SetBlind sets a single blinding scalar at position index in the internal register. +func (c *Client) SetBlind(index int, blind *group.Scalar) *Client { + c.Blinds[index] = c.Group.NewScalar().Set(blind) + return c +} + +// Blind uses the blinding scalar at position index in the internal register to blind the input, and return the blinded +// input. +func (c *Client) Blind(index int, input []byte) *group.Element { + // register input and blind + c.Inputs[index] = make([]byte, len(input)) + copy(c.Inputs[index], input) + + if c.Blinds[index] == nil { + c.Blinds[index] = c.Core.Group.NewScalar().Random() + } + + // blind input + p := c.HashToGroup(input) + if p.IsIdentity() { + panic(errInvalidInput) + } + + return p.Multiply(c.Blinds[index]) +} + +// Unblind uses the blinding scalar at position index in the internal register to unblind the evaluated element, and +// return the unblinded evaluation. +func (c *Client) Unblind(index int, evaluated *group.Element) *group.Element { + inv := c.Blinds[index].Copy().Invert() + return evaluated.Copy().Multiply(inv) +} + +// Finalize finalizes the client's xOPRF execution. It takes a server evaluated element and the position in the internal +// blind register of the blind used in the blinding phase and returns the xOPRF output. The optional info argument must +// only be provided when using the POPRF mode. +func (c *Client) Finalize(index int, evaluated *group.Element, info ...byte) []byte { + unblinded := c.Unblind(index, evaluated) + return c.HashTranscript(c.Inputs[index], unblinded.Encode(), info) +} + +// FinalizeBatch unblinds the evaluated elements and returns the corresponding protocol outputs. The optional info +// argument must only be provided when using the POPRF mode. +func (c *Client) FinalizeBatch(evaluated []*group.Element, info ...byte) ([][]byte, error) { + if len(evaluated) == 0 { + return nil, errBatchNoElements + } + + if len(evaluated) != c.Size() { + return nil, errBatchDifferentSize + } + + out := make([][]byte, len(evaluated)) + + for i, e := range evaluated { + out[i] = c.Finalize(i, e, info...) + } + + return out, nil +} diff --git a/internal/configuration.go b/internal/configuration.go new file mode 100644 index 0000000..e50a6db --- /dev/null +++ b/internal/configuration.go @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (C) 2024 Daniel Bourdrez. All Rights Reserved. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree or at +// https://spdx.org/licenses/MIT.html + +// Package internal handles all core xOPRF functionalities. +package internal + +import ( + "errors" + "fmt" + + group "github.com/bytemare/crypto" + "github.com/bytemare/hash" +) + +// Mode distinguishes between the OPRF base mode and the VOPRF mode. +type Mode byte + +const ( + // OPRF identifies the base mode. + OPRF Mode = iota + + // VOPRF identifies the verifiable mode. + VOPRF + + // POPRF identifies the partially-oblivious mode. + POPRF +) + +const ( + // Version is a string explicitly stating the Version name. + Version = "OPRFV1" + + hash2groupDSTPrefix = "HashToGroup-" + hash2scalarDSTPrefix = "HashToScalar-" + dstSeed = "Seed-" + + contextStringPrefix = Version + "-" + + dstFinalize = "Finalize" + dstInfo = "Info" + + // deriveKeyPairDST is the DST prefix for the DeriveKeyPair function. + deriveKeyPairDST = "DeriveKeyPair" +) + +var ( + // CiphersuiteIdentifier maps a group to its [RFC9497](https://datatracker.ietf.org/doc/rfc9497) compliant + // identifier. + CiphersuiteIdentifier = map[group.Group]string{ + group.Ristretto255Sha512: "ristretto255-SHA512", + group.P256Sha256: "P256-SHA256", + group.P384Sha384: "P384-SHA384", + group.P521Sha512: "P521-SHA512", + group.Secp256k1: "secp256k1-SHA256", + } + + errInvalidInput = errors.New( + "invalid input - OPRF input deterministically maps to the group identity element", + ) + errInvalidPOPRFPrivateKey = errors.New( + "invalid input - POPRF private key tweaking yields the zero scalar", + ) + errInvalidPOPRFPubKey = errors.New( + "invalid input - POPRF public key tweaking yields the group identity element", + ) +) + +// A Core holds the cryptographic configuration and methods used for xOPRF operations. +type Core struct { + Hash hash.Hasher + dstH2gDST []byte + dstH2sDST []byte + Group group.Group + Mode Mode +} + +// ContextString builds the xOPRF constant string used for domain separation tags. +func ContextString(mode Mode, name string) []byte { + return []byte(contextStringPrefix + string(mode) + "-" + name) +} + +func makeCore(g group.Group, h hash.Hash, mode Mode) *Core { + ctx := ContextString(mode, CiphersuiteIdentifier[g]) + + return &Core{ + Group: g, + Hash: h.New(), + Mode: mode, + dstH2gDST: Dst(hash2groupDSTPrefix, ctx), + dstH2sDST: Dst(hash2scalarDSTPrefix, ctx), + } +} + +// LoadConfiguration returns a core configuration given the ciphersuite and mode. The info argument should only be +// provided in POPRF mode. +func LoadConfiguration(g group.Group, mode Mode) *Core { + switch g { + case group.Ristretto255Sha512: + return makeCore(group.Ristretto255Sha512, hash.SHA512, mode) + case group.P256Sha256: + return makeCore(group.P256Sha256, hash.SHA256, mode) + case group.P384Sha384: + return makeCore(group.P384Sha384, hash.SHA384, mode) + case group.P521Sha512: + return makeCore(group.P521Sha512, hash.SHA512, mode) + case group.Secp256k1: + return makeCore(group.Secp256k1, hash.SHA256, mode) + default: + panic(fmt.Sprintf("invalid OPRF dependency - Group: %v", g)) + } +} + +// DeriveKeyPair derives a private-public key pair given a secret seed and instance specific info. +func (c Core) DeriveKeyPair(seed, info []byte) (*group.Scalar, *group.Element) { + dst := concatenate([]byte(deriveKeyPairDST), ContextString(c.Mode, CiphersuiteIdentifier[c.Group])) + deriveInput := concatenate(seed, lengthPrefixEncode(info)) + + var counter uint8 + var sk *group.Scalar + + for sk == nil || sk.IsZero() { + if counter > 255 { + panic("impossible to generate non-zero scalar") + } + + sk = c.Group.HashToScalar(concatenate(deriveInput, []byte{counter}), dst) + counter++ + } + + return sk, c.Group.Base().Multiply(sk) +} + +// HashTranscript hashes a xOPRF run's transcript (without the blind) to produce the protocol's output. +func (c Core) HashTranscript(input, unblinded, poprfInfo []byte) []byte { + encInput := lengthPrefixEncode(input) + encElement := lengthPrefixEncode(unblinded) + encDST := []byte(dstFinalize) + + var h []byte + + if len(poprfInfo) != 0 { // POPRF + encInfo := lengthPrefixEncode(poprfInfo) + h = c.Hash.Hash(0, encInput, encInfo, encElement, encDST) + } else { // OPRF and VOPRF + h = c.Hash.Hash(0, encInput, encElement, encDST) + } + + return h +} + +// HashToScalar maps the input data to a scalar. +func (c Core) HashToScalar(data []byte) *group.Scalar { + return c.Group.HashToScalar(data, c.dstH2sDST) +} + +// HashToGroup maps the input data to an element of the Group. +func (c Core) HashToGroup(data []byte) *group.Element { + return c.Group.HashToGroup(data, c.dstH2gDST) +} diff --git a/internal/nizk.go b/internal/nizk.go new file mode 100644 index 0000000..54567aa --- /dev/null +++ b/internal/nizk.go @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (C) 2024 Daniel Bourdrez. All Rights Reserved. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree or at +// https://spdx.org/licenses/MIT.html + +package internal + +import ( + "errors" + + group "github.com/bytemare/crypto" +) + +const ( + dstComposite = "Composite" + dstChallenge = "Challenge" +) + +var errProofFailed = errors.New("proof fails") + +// Verifiable enables VOPRF and POPRF functions over OPRF operations. +type Verifiable struct { + *Core + POPRFInfo []byte + seedDST []byte +} + +// NewVerifiable returns a core configuration for VOPRF and POPRF given the ciphersuite and mode. +// The info argument should only be provided in POPRF mode. +func NewVerifiable(c *Core, info []byte) *Verifiable { + if len(info) != 0 && c.Mode != POPRF { + panic("internal error: POPRF info provided but POPRF mode not set") + } + + ctx := ContextString(c.Mode, CiphersuiteIdentifier[c.Group]) + + return &Verifiable{ + Core: c, + POPRFInfo: info, + seedDST: Dst(dstSeed, ctx), + } +} + +func (v Verifiable) challenge(encPks []byte, a0, a1, a2, a3 *group.Element) *group.Scalar { + encA0 := lengthPrefixEncode(a0.Encode()) + encA1 := lengthPrefixEncode(a1.Encode()) + encA2 := lengthPrefixEncode(a2.Encode()) + encA3 := lengthPrefixEncode(a3.Encode()) + encDST := []byte(dstChallenge) + input := concatenate(encPks, encA0, encA1, encA2, encA3, encDST) + + return v.HashToScalar(input) +} + +func (v Verifiable) pTag(info []byte) *group.Scalar { + framedInfo := make([]byte, 0, len(dstInfo)+2+len(info)) // dstInfo + lengthPrefixEncode(info) + framedInfo = append(framedInfo, dstInfo...) + framedInfo = append(framedInfo, lengthPrefixEncode(info)...) + + return v.HashToScalar(framedInfo) +} + +// TweakPrivateKey tweaks the input scalar for use in the POPRF setting. +func (v Verifiable) TweakPrivateKey(privateKey *group.Scalar) (*group.Scalar, *group.Scalar) { + context := v.pTag(v.POPRFInfo) + t := privateKey.Copy().Add(context) + scalar := t.Copy().Invert() + + if scalar.IsZero() { + panic(errInvalidPOPRFPrivateKey) + } + + return scalar, t +} + +// TweakPublicKey tweaks the input element for use in the POPRF setting. +func (v Verifiable) TweakPublicKey(pubKey *group.Element) *group.Element { + m := v.pTag(v.POPRFInfo) + + t := v.Group.Base().Multiply(m).Add(pubKey) + if t.IsIdentity() { + panic(errInvalidPOPRFPubKey) + } + + return t +} + +// GenerateProof produces a non-interactive zero-knowledge (NIZK) proof on the evaluated elements. +func (v Verifiable) GenerateProof( + random, k *group.Scalar, + pk *group.Element, + cs, ds []*group.Element, +) (*group.Scalar, *group.Scalar) { + encPk := lengthPrefixEncode(pk.Encode()) + a0, a1 := v.computeComposites(k, encPk, cs, ds) + + a2 := v.Group.Base().Multiply(random) + a3 := a0.Copy().Multiply(random) + + proofC := v.challenge(encPk, a0, a1, a2, a3) + proofS := random.Subtract(proofC.Copy().Multiply(k)) + + return proofC, proofS +} + +// VerifyProof verifies the non-interactive zero-knowledge (NIZK) proof on the evaluated elements produced by +// GenerateProof. +func (v Verifiable) VerifyProof(proofC, proofS *group.Scalar, pubKey *group.Element, cs, ds []*group.Element) error { + encGk := lengthPrefixEncode(pubKey.Encode()) + a0, a1 := v.computeComposites(nil, encGk, cs, ds) + + ap := pubKey.Copy().Multiply(proofC) + a2 := v.Group.Base().Multiply(proofS).Add(ap) + + bm := a0.Copy().Multiply(proofS) + bz := a1.Copy().Multiply(proofC) + a3 := bm.Add(bz) + expectedC := v.challenge(encGk, a0, a1, a2, a3) + + if !ctEqual(expectedC.Encode(), proofC.Encode()) { + return errProofFailed + } + + return nil +} + +func (v Verifiable) ccScalar(encSeed []byte, index int, ci, di *group.Element) *group.Scalar { + input := concatenate(encSeed, I2osp2(index), + lengthPrefixEncode(ci.Encode()), + lengthPrefixEncode(di.Encode()), + []byte(dstComposite)) + + return v.HashToScalar(input) +} + +func (v Verifiable) computeCompositesFast( + k *group.Scalar, + encSeed []byte, + cs, ds []*group.Element, +) (*group.Element, *group.Element) { + m := v.Group.NewElement().Identity() + + for i, ci := range cs { + di := v.ccScalar(encSeed, i, ci, ds[i]) + m = ci.Copy().Multiply(di).Add(m) + } + + return m, m.Copy().Multiply(k) +} + +func (v Verifiable) computeCompositesClient(encSeed []byte, cs, ds []*group.Element) (*group.Element, *group.Element) { + m := v.Group.NewElement().Identity() + z := v.Group.NewElement().Identity() + + for i, ci := range cs { + di := v.ccScalar(encSeed, i, ci, ds[i]) + m = ci.Copy().Multiply(di).Add(m) + z = ds[i].Copy().Multiply(di).Add(z) + } + + return m, z +} + +func (v Verifiable) computeComposites( + k *group.Scalar, + encGk []byte, + cs, ds []*group.Element, +) (*group.Element, *group.Element) { + encSeedDST := lengthPrefixEncode(v.seedDST) + + // build seed + seed := v.Hash.Hash(0, encGk, encSeedDST) + encSeed := lengthPrefixEncode(seed) + + // This means where calling from the server, and can optimize computation of Z, since Zi = sks * Mi + if k != nil { + return v.computeCompositesFast(k, encSeed, cs, ds) + } + + return v.computeCompositesClient(encSeed, cs, ds) +} diff --git a/utils.go b/internal/utils.go similarity index 72% rename from utils.go rename to internal/utils.go index a08cb6d..855422a 100644 --- a/utils.go +++ b/internal/utils.go @@ -6,7 +6,7 @@ // LICENSE file in the root directory of this source tree or at // https://spdx.org/licenses/MIT.html -package voprf +package internal import ( "crypto/subtle" @@ -16,14 +16,15 @@ import ( ) // KeyPair assembles a VOPRF key pair. The SecretKey can be used as the evaluation key for -// the group identified by Ciphersuite. +// the Group identified by Ciphersuite. type KeyPair struct { PublicKey *group.Element SecretKey *group.Scalar - Ciphersuite Ciphersuite + Ciphersuite group.Group } -func i2osp2(value int) []byte { +// I2osp2 encodes the integer to a 2-byte byte string. +func I2osp2(value int) []byte { out := make([]byte, 2) binary.BigEndian.PutUint16(out, uint16(value)) @@ -31,7 +32,7 @@ func i2osp2(value int) []byte { } func lengthPrefixEncode(input []byte) []byte { - return append(i2osp2(len(input)), input...) + return append(I2osp2(len(input)), input...) } func ctEqual(a, b []byte) bool { @@ -61,11 +62,7 @@ func concatenate(input ...[]byte) []byte { return buf } -func dst(prefix string, contextString []byte) []byte { - p := []byte(prefix) - t := make([]byte, 0, len(p)+len(contextString)) - t = append(t, p...) - t = append(t, contextString...) - - return t +// Dst returns the domain separation tag, i.e. the concatenation of the input. +func Dst(prefix string, contextString []byte) []byte { + return []byte(prefix + string(contextString)) } diff --git a/oprf.go b/oprf.go index 0c40301..0a959e6 100644 --- a/oprf.go +++ b/oprf.go @@ -6,312 +6,127 @@ // LICENSE file in the root directory of this source tree or at // https://spdx.org/licenses/MIT.html +// Package voprf implements RFC9497 and provides abstracted access to Oblivious Pseudorandom Functions (OPRF) and +// Threshold Oblivious Pseudorandom Functions (TOPRF) using Elliptic Curve Prime Order Groups (EC-OPRF). +// For VOPRF and POPRF use the github.com/bytemare/oprf/voprf package. package voprf import ( - "fmt" - group "github.com/bytemare/crypto" - "github.com/bytemare/hash" -) - -// Mode distinguishes between the OPRF base mode and the VOPRF mode. -type Mode byte - -const ( - // OPRF identifies the base mode. - OPRF Mode = iota - // VOPRF identifies the verifiable mode. - VOPRF - - // POPRF identifies the partially-oblivious mode. - POPRF + "github.com/bytemare/voprf/internal" ) -// Ciphersuite of the OPRF compatible cipher suite to be used. -type Ciphersuite string +// Ciphersuite of the xOPRF compatible cipher suite to be used. +type Ciphersuite byte const ( - // Ristretto255Sha512 is the OPRF cipher suite of the Ristretto255 group and SHA-512. - Ristretto255Sha512 Ciphersuite = "ristretto255-SHA512" - - // Decaf448Sha512 is the OPRF cipher suite of the Decaf448 group and SHA-512. - // decaf448Sha512 Ciphersuite = "decaf448-SHAKE256". - - // P256Sha256 is the OPRF cipher suite of the NIST P-256 group and SHA-256. - P256Sha256 Ciphersuite = "P256-SHA256" - - // P384Sha384 is the OPRF cipher suite of the NIST P-384 group and SHA-384. - P384Sha384 Ciphersuite = "P384-SHA384" + // Ristretto255Sha512 identifies the Ristretto255 group and SHA-512. + Ristretto255Sha512 = Ciphersuite(group.Ristretto255Sha512) - // P521Sha512 is the OPRF cipher suite of the NIST P-512 group and SHA-512. - P521Sha512 Ciphersuite = "P521-SHA512" + // decaf448Shake256 identifies the Decaf448 group and Shake-256. Not supported. + // decaf448Shake256 = 2. - // Secp256k1 is the OPRF cipher suite of the SECp256k1 group and SHA-256. - Secp256k1 Ciphersuite = "secp256k1-SHA256" + // P256Sha256 identifies the NIST P-256 group and SHA-256. + P256Sha256 = Ciphersuite(group.P256Sha256) - nbIDs = 5 + // P384Sha384 identifies the NIST P-384 group and SHA-384. + P384Sha384 = Ciphersuite(group.P384Sha384) - // Version is a string explicitly stating the Version name. - Version = "OPRFV1" + // P521Sha512 identifies the NIST P-512 group and SHA-512. + P521Sha512 = Ciphersuite(group.P521Sha512) - // deriveKeyPairDST is the DST prefix for the DeriveKeyPair function. - deriveKeyPairDST = "DeriveKeyPair" - - // hash2groupDSTPrefix is the DST prefix to use for HashToGroup operations. - hash2groupDSTPrefix = "HashToGroup-" - - // hash2scalarDSTPrefix is the DST prefix to use for HashToScalar operations. - hash2scalarDSTPrefix = "HashToScalar-" -) - -var ( - groups = make(map[Ciphersuite]group.Group, nbIDs) - hashes = make(map[Ciphersuite]hash.Hash, nbIDs) + // Secp256k1 identifies the SECp256k1 group and SHA-256. + Secp256k1 = Ciphersuite(group.Secp256k1) ) -func (c Ciphersuite) new(mode Mode) *oprf { - return &oprf{ - hash: hashes[c].New(), - contextString: contextString(mode, c), - ciphersuite: c, - mode: mode, - group: groups[c], - } +// FromGroup returns a Ciphersuite given a Group. +func FromGroup(g group.Group) Ciphersuite { + return Ciphersuite(g) } -// Available returns whether the Ciphersuite is registered and available for usage. -func (c Ciphersuite) Available() bool { - // Check for invalid identifiers - switch c { - case Ristretto255Sha512, P256Sha256, P384Sha384, P521Sha512, Secp256k1: - break - default: - return false - } - - // Check for unregistered groups and hashes - if _, ok := groups[c]; !ok { - return false - } - - if _, ok := hashes[c]; !ok { - return false - } - - return true -} - -// Group returns the group identifier used in the cipher suite. +// Group returns the elliptic curve prime-order group of the ciphersuite. func (c Ciphersuite) Group() group.Group { - return groups[c] -} - -// Hash returns the hash function identifier used in the cipher suite. -func (c Ciphersuite) Hash() hash.Hash { - return hashes[c] -} - -// FromGroup returns a (V)OPRF Ciphersuite given a Group Ciphersuite. -func FromGroup(g group.Group) (Ciphersuite, error) { - for k, v := range groups { - if v == g { - return k, nil - } - } - - return "", errParamInvalidID -} - -// KeyGen returns a fresh KeyPair for the given cipher suite. -func (c Ciphersuite) KeyGen() *KeyPair { - sk := c.Group().NewScalar().Random() - pk := c.Group().Base().Multiply(sk) - - return &KeyPair{ - Ciphersuite: c, - PublicKey: pk, - SecretKey: sk, - } + return group.Group(c) } -// DeriveKeyPair deterministically generates a private and public key pair from input seed. -func (c Ciphersuite) DeriveKeyPair(mode Mode, seed, info []byte) *KeyPair { - dst := concatenate([]byte(deriveKeyPairDST), contextString(mode, c)) - deriveInput := concatenate(seed, lengthPrefixEncode(info)) - - var counter uint8 - var sk *group.Scalar - - for sk == nil || sk.IsZero() { - if counter > 255 { - panic("impossible to generate non-zero scalar") - } - - sk = c.Group().HashToScalar(concatenate(deriveInput, []byte{counter}), dst) - counter++ - } - - return &KeyPair{ - Ciphersuite: c, - PublicKey: c.Group().Base().Multiply(sk), - SecretKey: sk, - } +// Name returns the [RFC9497](https://datatracker.ietf.org/doc/rfc9497) compliant identifier of the ciphersuite. +func (c Ciphersuite) Name() string { + return internal.CiphersuiteIdentifier[group.Group(c)] } -// Client returns a (P|V)OPRF client. For the OPRF mode, serverPublicKey should be nil, and non-nil otherwise. -func (c Ciphersuite) Client(mode Mode, serverPublicKey []byte) (*Client, error) { - if mode != OPRF && mode != VOPRF && mode != POPRF { - return nil, errParamInvalidMode - } - - client := c.client(mode) - - if mode == VOPRF || mode == POPRF { - if serverPublicKey == nil { - return nil, errParamNoPubKey - } - - if err := client.setServerPublicKey(serverPublicKey); err != nil { - return nil, err - } - } - - return client, nil +// DeriveKeyPair returns a private-public key pair for the OPRF mode, given a secret seed and instance specific info. +// VOPRF and POPRF keys must be created with server.DeriveKeyPair() in the voprf package. +// TOPRF key pairs should be created using a distributed key generation protocol. +func DeriveKeyPair(c Ciphersuite, seed, info []byte) (*group.Scalar, *group.Element) { + // We don't use this as a method to a Ciphersuite, as it might be confusing when in VOPRF or POPRF mode, which + // use the Ciphersuite identifier from this package. + return internal.LoadConfiguration(c.Group(), internal.OPRF).DeriveKeyPair(seed, info) } -// Server returns a (P|V)OPRF server instantiated with the given encoded private key. -// If privateKey is nil, a new private/public key pair is created. -func (c Ciphersuite) Server(mode Mode, privateKey []byte) (*Server, error) { - if mode != OPRF && mode != VOPRF && mode != POPRF { - return nil, errParamInvalidMode +// Client returns an OPRF client. +func (c Ciphersuite) Client() *Client { + return &Client{ + Client: internal.NewClient(internal.OPRF, group.Group(c)), } - - return c.server(mode, privateKey) } -type oprf struct { - hash hash.Hasher - ciphersuite Ciphersuite - contextString []byte - mode Mode - group group.Group +// Client is used for OPRF and TOPRF client executions. +type Client struct { + *internal.Client } -func contextString(mode Mode, ciphersuite Ciphersuite) []byte { - ctx := make([]byte, 0, len(Version)+3+len(ciphersuite.String())) - ctx = append(ctx, Version...) - ctx = append(ctx, "-"...) - ctx = append(ctx, byte(mode)) - ctx = append(ctx, "-"...) - ctx = append(ctx, ciphersuite.String()...) - - return ctx -} - -// HashToGroup maps the input data to an element of the group. -func (o *oprf) HashToGroup(data []byte) *group.Element { - return o.group.HashToGroup(data, dst(hash2groupDSTPrefix, o.contextString)) -} +// SetBlind sets one or multiple blinds in the client's blind register. This is optional, and useful if you want to +// force usage of specific blinding scalar. If no blinding scalars are set, new, random blinds will be used. +func (c *Client) SetBlind(blind ...*group.Scalar) { + c.Client.UpdateStateCapacity(len(blind)) -// HashToScalar maps the input data to a scalar. -func (o *oprf) HashToScalar(data []byte) *group.Scalar { - return o.group.HashToScalar(data, dst(hash2scalarDSTPrefix, o.contextString)) -} - -func (c Ciphersuite) client(mode Mode) *Client { - return &Client{ - tweakedKey: nil, - serverPublicKey: nil, - oprf: c.new(mode), - input: nil, - blind: nil, - blindedElement: nil, + for i, b := range blind { + c.Client.SetBlind(i, b) } } -func (c *Client) setServerPublicKey(serverPublicKey []byte) error { - if serverPublicKey == nil { // OPRF - return nil - } - - pub := c.group.NewElement() - if err := pub.Decode(serverPublicKey); err != nil { - return fmt.Errorf("invalid public key: %w", err) - } - - c.serverPublicKey = pub - - return nil +// Blind blinds the input using the first blinding scalar in the Client's register. If no blinding scalars were +// previously set, new, random blinds will be used. +func (c *Client) Blind(input []byte) *group.Element { + return c.Client.Blind(0, input) } -func (c Ciphersuite) server(mode Mode, privateKey []byte) (*Server, error) { - s := &Server{ - privateKey: nil, - publicKey: nil, - oprf: c.new(mode), - } - - if privateKey == nil { - s.KeyGen() - } else { - sk := s.group.NewScalar() - if err := sk.Decode(privateKey); err != nil { - return nil, fmt.Errorf("invalid private key: %w", err) - } +// BlindBatch blinds the given set, using either previously set blinds in the same order (if they have been set) or +// newly generated random blinds. Note that if not enough blinds were set, new, random blinds will be used as necessary. +func (c *Client) BlindBatch(inputs [][]byte) []*group.Element { + c.UpdateStateCapacity(len(inputs)) + blindedInput := make([]*group.Element, len(inputs)) - s.privateKey = sk - s.publicKey = s.group.Base().Multiply(sk) + for i, in := range inputs { + blindedInput[i] = c.Client.Blind(i, in) } - return s, nil + return blindedInput } -func (o *oprf) pTag(info []byte) *group.Scalar { - framedInfo := make([]byte, 0, len(dstInfo)+2+len(info)) // dstContext + s.contextString + lengthPrefixEncode(info) - framedInfo = append(framedInfo, dstInfo...) - framedInfo = append(framedInfo, lengthPrefixEncode(info)...) - - return o.HashToScalar(framedInfo) +// Finalize unblinds the evaluated element and returns the protocol output. +func (c *Client) Finalize(evaluated *group.Element) []byte { + return c.Client.Finalize(0, evaluated) } -func (o *oprf) hashTranscript(input, info, unblinded []byte) []byte { - encInput := lengthPrefixEncode(input) - encElement := lengthPrefixEncode(unblinded) - encDST := []byte(dstFinalize) - - var h []byte - - if info == nil { // OPRF and VOPRF - h = o.hash.Hash(0, encInput, encElement, encDST) - } else { // POPRF - encInfo := lengthPrefixEncode(info) - h = o.hash.Hash(0, encInput, encInfo, encElement, encDST) - } - - return h +// FinalizeBatch unblinds the evaluated elements and returns the corresponding protocol outputs. +func (c *Client) FinalizeBatch(evaluated []*group.Element) ([][]byte, error) { + return c.Client.FinalizeBatch(evaluated) } -// String implements the Stringer() interface for the Ciphersuite. -func (c Ciphersuite) String() string { - return string(c) +// Evaluate is the server's function to evaluate a Client provided blinded element with the server's secret key. +func Evaluate(key *group.Scalar, blinded *group.Element) *group.Element { + return blinded.Copy().Multiply(key) } -func (c Ciphersuite) register(g group.Group, h hash.Hash) { - if g.Available() && h.Available() { - groups[c] = g - hashes[c] = h - } else { - panic(fmt.Sprintf("OPRF dependencies not available - Group: %v, Hash: %v", g.Available(), h.Available())) +// EvaluateBatch is the server's function to evaluate a set of Client provided blinded elements with the +// server's secret key. +func EvaluateBatch(key *group.Scalar, blinded []*group.Element) []*group.Element { + evaluated := make([]*group.Element, len(blinded)) + for i, b := range blinded { + evaluated[i] = Evaluate(key, b) } -} -func init() { - Ristretto255Sha512.register(group.Ristretto255Sha512, hash.SHA512) - // Decaf448Sha512.register(group.Curve448Sha512, hash.SHA512). - P256Sha256.register(group.P256Sha256, hash.SHA256) - P384Sha384.register(group.P384Sha384, hash.SHA384) - P521Sha512.register(group.P521Sha512, hash.SHA512) - Secp256k1.register(group.Secp256k1, hash.SHA256) + return evaluated } diff --git a/server.go b/server.go deleted file mode 100644 index d33f7ff..0000000 --- a/server.go +++ /dev/null @@ -1,144 +0,0 @@ -// SPDX-License-Identifier: MIT -// -// Copyright (C) 2024 Daniel Bourdrez. All Rights Reserved. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree or at -// https://spdx.org/licenses/MIT.html - -package voprf - -import ( - "fmt" - - group "github.com/bytemare/crypto" -) - -// Server holds the (V)OPRF prover data. -type Server struct { - privateKey *group.Scalar - publicKey *group.Element - *oprf -} - -// KeyGen generates and sets a new private/public key pair. -func (s *Server) KeyGen() { - s.privateKey = s.group.NewScalar().Random() - s.publicKey = s.group.Base().Multiply(s.privateKey) -} - -// Evaluate the input with the private key. -func (s *Server) Evaluate(blindedElement, info []byte) (*Evaluation, error) { - return s.innerEvaluateBatch([][]byte{blindedElement}, nil, info) -} - -// EvaluateWithRandom does the same as Evaluate and allows to provide a random input for proof generation. -func (s *Server) EvaluateWithRandom(blindedElement, random, info []byte) (*Evaluation, error) { - return s.innerEvaluateBatch([][]byte{blindedElement}, random, info) -} - -func (s *Server) getPrivateKeys(info []byte) (scalar, t *group.Scalar, err error) { - if s.mode == POPRF { - context := s.pTag(info) - t = s.privateKey.Copy().Add(context) - scalar = t.Copy().Invert() - - if scalar.IsZero() { - return nil, nil, errZeroScalar - } - } else { - scalar = s.privateKey - } - - return scalar, t, nil -} - -func setRandom(r *group.Scalar, random []byte) error { - if len(random) == 0 { - r.Random() - } else { - if err := r.Decode(random); err != nil { - return fmt.Errorf("decoding input random scalar: %w", err) - } - } - - return nil -} - -func (s *Server) innerEvaluateBatch(blindedElements [][]byte, random, info []byte) (*Evaluation, error) { - ev := &evaluation{ - proofC: nil, - proofS: nil, - elements: nil, - } - ev.elements = make([]*group.Element, len(blindedElements)) - - var blinded []*group.Element - var scalar, t *group.Scalar - - scalar, t, err := s.getPrivateKeys(info) - if err != nil { - return nil, err - } - - var r *group.Scalar - - if s.mode == VOPRF || s.mode == POPRF { - blinded = make([]*group.Element, len(blindedElements)) - - r = s.group.NewScalar() - if err := setRandom(r, random); err != nil { - return nil, err - } - } - - // decode and evaluate element(s) - for i, bytes := range blindedElements { - b := s.group.NewElement() - if err := b.Decode(bytes); err != nil { - return nil, fmt.Errorf("OPRF can't evaluate input : %w", err) - } - - if s.mode == VOPRF || s.mode == POPRF { - blinded[i] = b - } - - ev.elements[i] = b.Copy().Multiply(scalar) - } - - // generate proof - if s.mode == VOPRF { - ev.proofC, ev.proofS = s.oprf.generateProof(r, s.privateKey, s.publicKey, blinded, ev.elements) - } else if s.mode == POPRF { - tweakedKey := s.group.Base().Multiply(t) - ev.proofC, ev.proofS = s.oprf.generateProof(r, t, tweakedKey, ev.elements, blinded) - } - - return ev.serialize(), nil -} - -// EvaluateBatch evaluates the input batch of blindedElements and returns a pointer to the Evaluation. If the server -// was set to be un VOPRF mode, the proof will be included in the Evaluation. -func (s *Server) EvaluateBatch(blindedElements [][]byte, info []byte) (*Evaluation, error) { - return s.innerEvaluateBatch(blindedElements, nil, info) -} - -// EvaluateBatchWithRandom does the same as EvaluateBatch and allows to provide a random input for proof generation. -func (s *Server) EvaluateBatchWithRandom(blindedElements [][]byte, random, info []byte) (*Evaluation, error) { - return s.innerEvaluateBatch(blindedElements, random, info) -} - -// PrivateKey returns the server's serialized private key. -func (s *Server) PrivateKey() []byte { - return s.privateKey.Encode() -} - -// PublicKey returns the server's serialized public key. -func (s *Server) PublicKey() []byte { - return s.publicKey.Encode() -} - -// Ciphersuite returns the cipher suite used in the server's instance. -func (s *Server) Ciphersuite() Ciphersuite { - return s.oprf.ciphersuite -} diff --git a/tests/helper_test.go b/tests/helper_test.go index 93e88b4..d6bc521 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -11,7 +11,6 @@ package voprf_test import ( "crypto/elliptic" "crypto/rand" - "encoding/binary" "encoding/hex" "fmt" "log" @@ -21,7 +20,7 @@ import ( group "github.com/bytemare/crypto" "github.com/bytemare/hash" - "github.com/bytemare/voprf" + oprf "github.com/bytemare/voprf" ) func init() { @@ -32,8 +31,8 @@ func init() { type configuration struct { curve elliptic.Curve - ciphersuite voprf.Ciphersuite name string + ciphersuite oprf.Ciphersuite hash hash.Hash group group.Group } @@ -41,35 +40,35 @@ type configuration struct { var configurationTable = []configuration{ { name: "Ristretto255", - ciphersuite: voprf.Ristretto255Sha512, + ciphersuite: oprf.Ristretto255Sha512, group: group.Ristretto255Sha512, hash: hash.SHA512, curve: nil, }, { name: "P256Sha256", - ciphersuite: voprf.P256Sha256, + ciphersuite: oprf.P256Sha256, group: group.P256Sha256, hash: hash.SHA256, curve: elliptic.P256(), }, { name: "P384Sha512", - ciphersuite: voprf.P384Sha384, + ciphersuite: oprf.P384Sha384, group: group.P384Sha384, hash: hash.SHA384, curve: elliptic.P384(), }, { name: "P521Sha512", - ciphersuite: voprf.P521Sha512, + ciphersuite: oprf.P521Sha512, group: group.P521Sha512, hash: hash.SHA512, curve: elliptic.P521(), }, { name: "Secp256k1Sha256", - ciphersuite: voprf.Secp256k1, + ciphersuite: oprf.Secp256k1, group: group.Secp256k1, hash: hash.SHA256, curve: nil, @@ -137,7 +136,7 @@ func getBadNistElement(t *testing.T, g group.Group) []byte { func getBadElement(t *testing.T, c *configuration) []byte { switch c.ciphersuite { - case voprf.Ristretto255Sha512: + case oprf.Ristretto255Sha512: return getBadRistrettoElement() default: return getBadNistElement(t, c.ciphersuite.Group()) @@ -146,7 +145,7 @@ func getBadElement(t *testing.T, c *configuration) []byte { func getBadScalar(t *testing.T, c *configuration) []byte { switch c.ciphersuite { - case voprf.Ristretto255Sha512: + case oprf.Ristretto255Sha512: return getBadRistrettoScalar() default: return badScalar(t, c.ciphersuite.Group(), c.curve) @@ -154,79 +153,5 @@ func getBadScalar(t *testing.T, c *configuration) []byte { } const ( - deriveKeyPairDST = "DeriveKeyPair" hash2groupDSTPrefix = "HashToGroup-" ) - -func concatenate(input ...[]byte) []byte { - if len(input) == 1 { - if len(input[0]) == 0 { - return nil - } - - return input[0] - } - - length := 0 - for _, in := range input { - length += len(in) - } - - buf := make([]byte, 0, length) - - for _, in := range input { - buf = append(buf, in...) - } - - return buf -} - -func dst(prefix string, contextString []byte) []byte { - p := []byte(prefix) - t := make([]byte, 0, len(p)+len(contextString)) - t = append(t, p...) - t = append(t, contextString...) - - return t -} - -func i2osp2(value int) []byte { - out := make([]byte, 2) - binary.BigEndian.PutUint16(out, uint16(value)) - - return out -} - -func lengthPrefixEncode(input []byte) []byte { - return append(i2osp2(len(input)), input...) -} - -func contextString(mode voprf.Mode, g voprf.Ciphersuite) []byte { - ctx := make([]byte, 0, len(voprf.Version)+3+len(g.String())) - ctx = append(ctx, voprf.Version...) - ctx = append(ctx, "-"...) - ctx = append(ctx, byte(mode)) - ctx = append(ctx, "-"...) - ctx = append(ctx, g.String()...) - - return ctx -} - -func deriveKeyPair(seed, info []byte, mode voprf.Mode, g voprf.Ciphersuite) (*group.Scalar, *group.Element) { - dst := concatenate([]byte(deriveKeyPairDST), contextString(mode, g)) - deriveInput := concatenate(seed, lengthPrefixEncode(info)) - - var counter uint8 - var s *group.Scalar - - for s == nil || s.IsZero() { - if counter > 255 { - panic("impossible to generate non-zero scalar") - } - - s = g.Group().HashToScalar(concatenate(deriveInput, []byte{counter}), dst) - counter++ - } - - return s, g.Group().Base().Multiply(s) -} diff --git a/tests/state_test.go b/tests/state_test.go deleted file mode 100644 index c71e763..0000000 --- a/tests/state_test.go +++ /dev/null @@ -1,157 +0,0 @@ -// SPDX-License-Identifier: MIT -// -// Copyright (C) 2024 Daniel Bourdrez. All Rights Reserved. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree or at -// https://spdx.org/licenses/MIT.html - -package voprf_test - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "testing" - - "github.com/bytemare/voprf" -) - -func TestEvaluationSerde(t *testing.T) { - suite := voprf.Ristretto255Sha512 - input := []byte("input") - mode := voprf.OPRF - - server, err := suite.Server(mode, nil) - if err != nil { - t.Fatal(err) - } - - spk := server.PublicKey() - - client, err := suite.Client(mode, spk) - if err != nil { - t.Fatal(err) - } - - blinded := client.Blind(input, nil) - evaluation, err := server.Evaluate(blinded, nil) - if err != nil { - panic(err) - } - - ser := evaluation.Serialize() - deser := &voprf.Evaluation{} - - if err := deser.Deserialize(ser); err != nil { - t.Fatal(err) - } - - errSerDeFailed := errors.New("evaluation serde failed") - - if !areArraysOfArraysEqual(evaluation.Elements, deser.Elements) { - t.Fatal(errSerDeFailed) - } - - if bytes.Compare(evaluation.ProofC, evaluation.ProofC) != 0 { - t.Fatal(errSerDeFailed) - } - - if bytes.Compare(evaluation.ProofS, evaluation.ProofS) != 0 { - t.Fatal(errSerDeFailed) - } -} - -func serdeExport(t *testing.T, client *voprf.Client) (*voprf.State, *voprf.State) { - export := client.Export() - - serialized, err := json.Marshal(export) - if err != nil { - t.Fatal(err) - } - - state := &voprf.State{} - if err := json.Unmarshal(serialized, state); err != nil { - t.Fatal(err) - } - - return export, state -} - -func TestClientState(t *testing.T) { - suite := voprf.Ristretto255Sha512 - input := []byte("input") - kp := suite.KeyGen() // only used in VOPRF and POPRF - info := []byte("additional data") // only used in POPRF - - for _, mode := range []voprf.Mode{voprf.OPRF, voprf.VOPRF, voprf.POPRF} { - t.Run(fmt.Sprintf("State test for mode %v", mode), func(t *testing.T) { - client, err := suite.Client(mode, kp.PublicKey.Encode()) - if err != nil { - t.Fatal(err) - } - - client.Blind(input, info) - - export, state := serdeExport(t, client) - - resumed, err := state.RecoverClient() - if err != nil { - t.Fatal(err) - } - - export2 := resumed.Export() - - if !areStatesEqual(export, export2) { - t.Fatal("states are not equal") - } - }) - } -} - -func areArraysOfArraysEqual(a, b [][]byte) bool { - if len(a) != len(b) { - return false - } - - for i, c := range a { - if bytes.Compare(c, b[i]) != 0 { - return false - } - } - - return true -} - -func areStatesEqual(x1, x2 *voprf.State) bool { - if x1.Mode != x2.Mode { - return false - } - - if x1.Identifier != x2.Identifier { - return false - } - - if bytes.Compare(x1.TweakedKey, x1.TweakedKey) != 0 { - return false - } - - if bytes.Compare(x1.ServerPublicKey, x1.ServerPublicKey) != 0 { - return false - } - - if !areArraysOfArraysEqual(x1.Input, x2.Input) { - return false - } - - if !areArraysOfArraysEqual(x1.Blind, x2.Blind) { - return false - } - - if !areArraysOfArraysEqual(x1.Blinded, x2.Blinded) { - return false - } - - return true -} diff --git a/tests/utils_test.go b/tests/utils_test.go deleted file mode 100644 index f09fa0e..0000000 --- a/tests/utils_test.go +++ /dev/null @@ -1,9 +0,0 @@ -// SPDX-License-Identifier: MIT -// -// Copyright (C) 2024 Daniel Bourdrez. All Rights Reserved. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree or at -// https://spdx.org/licenses/MIT.html - -package voprf_test diff --git a/tests/vectors_test.go b/tests/vectors_test.go index 2d8b208..ea3a932 100644 --- a/tests/vectors_test.go +++ b/tests/vectors_test.go @@ -20,24 +20,27 @@ import ( group "github.com/bytemare/crypto" "github.com/bytemare/hash" - "github.com/bytemare/voprf" + oprf "github.com/bytemare/voprf" + "github.com/bytemare/voprf/internal" + "github.com/bytemare/voprf/voprf" ) type test struct { - Blind [][]byte - BlindedElement [][]byte + ServerPrivateKey *group.Scalar + ProofC *group.Scalar + NonceR *group.Scalar + ProofS *group.Scalar + Blind []*group.Scalar + BlindedElement []*group.Element Info []byte - EvaluationElement [][]byte - ProofC []byte - NonceR []byte - ProofS []byte + EvaluationElement []*group.Element Input [][]byte Output [][]byte Batch int + oprf.Ciphersuite } type testVector struct { - ID voprf.Ciphersuite `json:"proof,omitempty"` EvaluationProof struct { Proof string `json:"proof,omitempty"` Random string `json:"r,omitempty"` @@ -49,6 +52,7 @@ type testVector struct { Input string `json:"Input"` Output string `json:"Output"` Batch int `json:"Batch"` + Ciphersuite oprf.Ciphersuite } func decodeBatch(nb int, in string) ([][]byte, error) { @@ -70,44 +74,93 @@ func decodeBatch(nb int, in string) ([][]byte, error) { return out, nil } -func (t *test) Verify(suite voprf.Ciphersuite) error { - g := suite.Group() +func decodeBatchScalar(g group.Group, nb int, in string) ([]*group.Scalar, error) { + b, err := decodeBatch(nb, in) + if err != nil { + return nil, err + } - for i, b := range t.Blind { - if err := g.NewScalar().Decode(b); err != nil { - return fmt.Errorf("blind %d decoding: %w", i, err) + res := make([]*group.Scalar, nb) + for i, bi := range b { + res[i] = g.NewScalar() + if err := res[i].Decode(bi); err != nil { + return nil, err } } - for i, b := range t.BlindedElement { - if err := g.NewElement().Decode(b); err != nil { - return fmt.Errorf("blinded element %d decoding: %w", i, err) - } + return res, nil +} + +func decodeBatchElement(g group.Group, nb int, in string) ([]*group.Element, error) { + b, err := decodeBatch(nb, in) + if err != nil { + return nil, err } - for i, b := range t.EvaluationElement { - if err := g.NewElement().Decode(b); err != nil { - return fmt.Errorf("evaluation element %d decoding: %w", i, err) + res := make([]*group.Element, nb) + for i, bi := range b { + res[i] = g.NewElement() + if err := res[i].Decode(bi); err != nil { + return nil, err } } - return nil + return res, nil +} + +//func (t *test) Verify(suite oprf.Ciphersuite) error { +// g := suite.Group() +// +// for i, b := range t.Blind { +// if err := g.NewScalar().Decode(b); err != nil { +// return fmt.Errorf("blind %d decoding: %w", i, err) +// } +// } +// +// for i, b := range t.BlindedElement { +// if err := g.NewElement().Decode(b); err != nil { +// return fmt.Errorf("blinded element %d decoding: %w", i, err) +// } +// } +// +// for i, b := range t.EvaluationElement { +// if err := g.NewElement().Decode(b); err != nil { +// return fmt.Errorf("evaluation element %d decoding: %w", i, err) +// } +// } +// +// return nil +//} + +func decodeScalar(g group.Group, s string) (*group.Scalar, error) { + ds, err := hex.DecodeString(s) + if err != nil { + return nil, fmt.Errorf(" ProofC decoding errored with %q", err) + } + + out := g.NewScalar() + if err := out.Decode(ds); err != nil { + return nil, err + } + + return out, nil } func (tv *testVector) Decode() (*test, error) { - blind, err := decodeBatch(tv.Batch, tv.Blind) + g := tv.Ciphersuite.Group() + blind, err := decodeBatchScalar(g, tv.Batch, tv.Blind) // blind, err := hex.DecodeString(tv.Blind) if err != nil { return nil, fmt.Errorf(" Blind decoding errored with %q", err) } - blinded, err := decodeBatch(tv.Batch, tv.BlindedElement) + blinded, err := decodeBatchElement(g, tv.Batch, tv.BlindedElement) // blinded, err := hex.DecodeString(tv.BlindedElement) if err != nil { return nil, fmt.Errorf(" BlindedElement decoding errored with %q", err) } - evaluationElement, err := decodeBatch(tv.Batch, tv.EvaluationElement) + evaluationElement, err := decodeBatchElement(g, tv.Batch, tv.EvaluationElement) if err != nil { return nil, fmt.Errorf(" EvaluationElement decoding errored with %q", err) } @@ -117,23 +170,23 @@ func (tv *testVector) Decode() (*test, error) { return nil, fmt.Errorf(" info decoding errored with %q", err) } - var proofC, nonceR, proofS []byte + var proofC, nonceR, proofS *group.Scalar if len(tv.EvaluationProof.Proof) != 0 { pLen := len(tv.EvaluationProof.Proof) c := tv.EvaluationProof.Proof[:pLen/2] s := tv.EvaluationProof.Proof[pLen/2:] - proofC, err = hex.DecodeString(c) + proofC, err = decodeScalar(tv.Ciphersuite.Group(), c) if err != nil { return nil, fmt.Errorf(" ProofC decoding errored with %q", err) } - proofS, err = hex.DecodeString(s) + proofS, err = decodeScalar(tv.Ciphersuite.Group(), s) if err != nil { return nil, fmt.Errorf(" ProofS decoding errored with %q", err) } - nonceR, err = hex.DecodeString(tv.EvaluationProof.Random) + nonceR, err = decodeScalar(tv.Ciphersuite.Group(), tv.EvaluationProof.Random) if err != nil { return nil, fmt.Errorf(" NonceR decoding errored with %q", err) } @@ -152,6 +205,7 @@ func (tv *testVector) Decode() (*test, error) { } return &test{ + Ciphersuite: tv.Ciphersuite, Batch: tv.Batch, Blind: blind, BlindedElement: blinded, @@ -168,15 +222,15 @@ func (tv *testVector) Decode() (*test, error) { type vectors []vector type vector struct { - DST string `json:"groupDST"` - Hash string `json:"hash"` - KeyInfo string `json:"keyInfo"` - SksSeed string `json:"seed"` - PkSm string `json:"pkSm,omitempty"` - SkSm string `json:"skSm"` - SuiteID voprf.Ciphersuite `json:"identifier"` - TestVectors []testVector `json:"vectors,omitempty"` - Mode voprf.Mode `json:"mode"` + DST string `json:"groupDST"` + Hash string `json:"hash"` + KeyInfo string `json:"keyInfo"` + SksSeed string `json:"seed"` + PkSm string `json:"pkSm,omitempty"` + SkSm string `json:"skSm"` + SuiteID string `json:"identifier"` + TestVectors []testVector `json:"vectors,omitempty"` + Mode internal.Mode `json:"mode"` } func hashToHash(h string) hash.Hash { @@ -206,7 +260,7 @@ func hashToHash(h string) hash.Hash { func (v vector) checkParams(t *testing.T) { // Check mode - if v.Mode != voprf.OPRF && v.Mode != voprf.VOPRF && v.Mode != voprf.POPRF { + if v.Mode != internal.OPRF && v.Mode != internal.VOPRF && v.Mode != internal.POPRF { t.Fatalf("invalid mode %v", v.Mode) } @@ -226,75 +280,131 @@ func (v vector) checkParams(t *testing.T) { //} } -func testBlind(t *testing.T, ciphersuite voprf.Ciphersuite, client *voprf.Client, input, blind, expected, info []byte) { - s := ciphersuite.Group().NewScalar() - if err := s.Decode(blind); err != nil { - t.Fatal(fmt.Errorf("blind decoding to scalar in suite %v errored with %q", ciphersuite, err)) - } - - client.SetBlinds([]*group.Scalar{s}) +type Client interface { + Blind(input []byte) *group.Element + BlindBatch(inputs [][]byte) []*group.Element + SetBlind(blind ...*group.Scalar) +} - blinded := client.Blind(input, info) +func testBlind(t *testing.T, client Client, blind *group.Scalar, input []byte, expected *group.Element) { + client.SetBlind(blind) + blinded := client.Blind(input) - if !bytes.Equal(expected, blinded) { + if blinded.Equal(expected) != 1 { t.Fatal("unexpected blinded output") } } -func testBlindBatchWithBlinds(t *testing.T, client *voprf.Client, inputs, blinds, outputs [][]byte, info []byte) { - blinded, err := client.BlindBatchWithBlinds(blinds, inputs, info) - if err != nil { - t.Fatal(err) +func testBlindBatch(t *testing.T, client Client, blinds []*group.Scalar, inputs [][]byte, expected []*group.Element) { + client.SetBlind(blinds...) + + blinded := client.BlindBatch(inputs) + if len(blinded) != len(expected) { + t.Fatal("different number of blinded elements than expected") } - for i, o := range outputs { - if !bytes.Equal(o, blinded[i]) { - t.Fatal("unexpected blinded output") + for i, b := range expected { + if b.Equal(expected[i]) != 1 { + t.Fatalf("unexpected blinded output %d", i) } } } -func testOPRFServerEvaluation(t *testing.T, server *voprf.Server, test *test) *voprf.Evaluation { - var ev *voprf.Evaluation - var err error +func testOPRFEvaluation(t *testing.T, test *test) { + if len(test.BlindedElement) > 1 { + ev := oprf.Evaluate(test.ServerPrivateKey, test.BlindedElement[0]) - if test.Batch == 1 { - ev, err = server.EvaluateWithRandom(test.BlindedElement[0], test.NonceR, test.Info) - if err != nil { - t.Fatal(err) + if test.EvaluationElement[0].Equal(ev) != 1 { + t.Fatal("unexpected evaluation element") + } + } else { + ev := oprf.EvaluateBatch(test.ServerPrivateKey, test.BlindedElement) + + if len(ev) != len(test.BlindedElement) { + t.Fatal("unequal length") } - if !bytes.Equal(test.EvaluationElement[0], ev.Elements[0]) { - t.Fatal("unexpected evaluation element") + for i, e := range ev { + if test.EvaluationElement[i].Equal(e) != 1 { + t.Fatal("unexpected evaluation element") + } + } + } +} + +func testOPRFFinalize(t *testing.T, client *oprf.Client, test *test) { + if test.Batch == 1 { + output := client.Finalize(test.EvaluationElement[0]) + + if !bytes.Equal(test.Output[0], output) { + t.Fatal("finalize() output is not valid.") } } else { - ev, err = server.EvaluateBatchWithRandom(test.BlindedElement, test.NonceR, test.Info) + output, err := client.FinalizeBatch(test.EvaluationElement) if err != nil { t.Fatal(err) } + for i, o := range test.Output { + if !bytes.Equal(o, output[i]) { + t.Fatal("finalizeBatch() output is not valid.") + } + } + } +} + +func testVPOPRFEvaluation(t *testing.T, server *voprf.Server, test *test) { + var evaluation *voprf.Evaluation + if test.Batch == 1 { + evaluation = server.Evaluate(test.BlindedElement[0], test.NonceR) + + if evaluation.Evaluations[0].Equal(test.EvaluationElement[0]) != 1 { + t.Fatalf( + "unexpected evaluation element:\n\twant: %v\n\tgot : %v\n", + hex.EncodeToString(test.EvaluationElement[0].Encode()), + hex.EncodeToString(evaluation.Evaluations[0].Encode()), + ) + } + } else { + evaluation = server.EvaluateBatch(test.BlindedElement, test.NonceR) + for i, e := range test.EvaluationElement { - if !bytes.Equal(e, ev.Elements[i]) { + if e.Equal(evaluation.Evaluations[i]) != 1 { t.Fatal("unexpected evaluation elements") } } } - return ev + if evaluation.Proof[0].Equal(test.ProofC) != 1 { + t.Fatal("unexpected proof c") + } + + if evaluation.Proof[1].Equal(test.ProofS) != 1 { + t.Fatal("unexpected proof s") + } } -func testOPRFClientFinalize(t *testing.T, client *voprf.Client, ev *voprf.Evaluation, test *test) { +func testVPOPRFFinalize(t *testing.T, client *voprf.Client, test *test) { + evaluation := &voprf.Evaluation{ + Proof: [2]*group.Scalar{test.ProofC, test.ProofS}, + Evaluations: test.EvaluationElement, + } + if test.Batch == 1 { - output, err := client.Finalize(ev, test.Info) + output, err := client.Finalize(evaluation) if err != nil { t.Fatal(err) } if !bytes.Equal(test.Output[0], output) { - t.Fatal("finalize() output is not valid.") + t.Fatalf( + "finalize() output is not valid.\n\twant: %s\n\tgot : %s\n", + hex.EncodeToString(test.Output[0]), + hex.EncodeToString(output), + ) } } else { - output, err := client.FinalizeBatch(ev, test.Info) + output, err := client.FinalizeBatch(evaluation) if err != nil { t.Fatal(err) } @@ -309,59 +419,68 @@ func testOPRFClientFinalize(t *testing.T, client *voprf.Client, ev *voprf.Evalua func testOPRF( t *testing.T, - ciphersuite voprf.Ciphersuite, - mode voprf.Mode, - client *voprf.Client, - server *voprf.Server, test *test, ) { + client := test.Ciphersuite.Client() + // OPRFClient Blinding if test.Batch == 1 { - testBlind(t, ciphersuite, client, test.Input[0], test.Blind[0], test.BlindedElement[0], test.Info) + testBlind(t, client, test.Blind[0], test.Input[0], test.BlindedElement[0]) } else { - testBlindBatchWithBlinds(t, client, test.Input, test.Blind, test.BlindedElement, test.Info) + testBlindBatch(t, client, test.Blind, test.Input, test.BlindedElement) } // OPRFServer evaluating - ev := testOPRFServerEvaluation(t, server, test) - - // Verify proofs - if mode == voprf.VOPRF || mode == voprf.POPRF { - if !bytes.Equal(test.ProofC, ev.ProofC) { - t.Errorf( - "unexpected c proof\n\twant %v\n\tgot %v", - hex.EncodeToString(test.ProofC), - hex.EncodeToString(ev.ProofC), - ) - } + testOPRFEvaluation(t, test) - if !bytes.Equal(test.ProofS, ev.ProofS) { - t.Errorf( - "unexpected s proof\n\twant %v\n\tgot %v", - hex.EncodeToString(test.ProofS), - hex.EncodeToString(ev.ProofS), - ) - } + // OPRFClient finalize + testOPRFFinalize(t, client, test) +} + +func testVPOPRF( + t *testing.T, + test *test, +) { + sk, pk := test.ServerPrivateKey, test.Ciphersuite.Group().Base().Multiply(test.ServerPrivateKey) + server := voprf.NewServer(test.Ciphersuite, test.Info...) + if err := server.SetKeyPair(sk, pk); err != nil { + t.Fatal(err) + } + + client, err := voprf.NewClient(test.Ciphersuite, pk, test.Info...) + if err != nil { + t.Fatal(err) } + // OPRFClient Blinding + if test.Batch == 1 { + testBlind(t, client, test.Blind[0], test.Input[0], test.BlindedElement[0]) + } else { + testBlindBatch(t, client, test.Blind, test.Input, test.BlindedElement) + } + + // OPRFServer evaluating + testVPOPRFEvaluation(t, server, test) + // OPRFClient finalize - testOPRFClientFinalize(t, client, ev, test) + testVPOPRFFinalize(t, client, test) } func (v vector) testVector( t *testing.T, - tv *testVector, - suite voprf.Ciphersuite, - mode voprf.Mode, - privKey, serverPublicKey, expectedDST []byte, + test *test, ) { - test, err := tv.Decode() + expectedDST, err := hex.DecodeString(v.DST) if err != nil { - t.Fatal(fmt.Sprintf("batches : %v Failed %v\n", tv.Batch, err)) + t.Fatalf("hex decoding errored with %q", err) } - if err := test.Verify(suite); err != nil { - t.Fatal(err) + if string( + expectedDST, + ) != string( + internal.Dst(hash2groupDSTPrefix, internal.ContextString(v.Mode, test.Ciphersuite.Name())), + ) { + t.Fatal("GroupDST output is not valid.") } // Test DeriveKeyPair @@ -375,72 +494,67 @@ func (v vector) testVector( t.Fatal(err) } - sks, _ := deriveKeyPair(seed, keyInfo, mode, suite) - // log.Printf("sks %v", hex.EncodeToString(serializeScalar(sks, scalarLength(o.id)))) - if !bytes.Equal(sks.Encode(), privKey) { - t.Fatalf("DeriveKeyPair yields unexpected output\n\twant: %v\n\tgot : %v", privKey, sks.Encode()) - } - - // Set up a new server. - server, err := suite.Server(mode, privKey) + privKey, err := hex.DecodeString(v.SkSm) if err != nil { - t.Fatalf( - "failed on setting up server %q\nvector value (%d) %v\ndecoded (%d) %v\n", - err, - len(v.SkSm), - v.SkSm, - len(privKey), - privKey, - ) + t.Fatalf("private key decoding errored with %q\nfor sksm %v\n", err, v.SkSm) } - if string(expectedDST) != string(dst(hash2groupDSTPrefix, contextString(mode, suite))) { - t.Fatal("GroupDST output is not valid.") - } + var sks *group.Scalar - client, err := suite.Client(mode, serverPublicKey) - if err != nil { - t.Fatal(err) + if v.Mode == internal.OPRF { + sks, _ = oprf.DeriveKeyPair(test.Ciphersuite, seed, keyInfo) + } else { + server := voprf.NewServer(test.Ciphersuite, test.Info...) + server.DeriveKeyPair(seed, keyInfo) + sks, _ = server.KeyPair() } - if string(expectedDST) != string(dst(hash2groupDSTPrefix, contextString(mode, suite))) { - t.Fatal("GroupDST output is not valid.") + if !bytes.Equal(sks.Encode(), privKey) { + t.Fatalf("DeriveKeyPair yields unexpected output\n\twant: %v\n\tgot : %v", privKey, sks.Encode()) } + test.ServerPrivateKey = sks + // test protocol execution - testOPRF(t, v.SuiteID, mode, client, server, test) + if v.Mode == internal.OPRF { + testOPRF(t, test) + } else { + testVPOPRF(t, test) + } +} + +func suiteToCiphersuite(t *testing.T, s string) oprf.Ciphersuite { + switch s { + case "ristretto255-SHA512": + return oprf.Ristretto255Sha512 + case "decaf448-SHAKE256": + t.Fatal("decaf not supported") + case "P256-SHA256": + return oprf.P256Sha256 + case "P384-SHA384": + return oprf.P384Sha384 + case "P521-SHA512": + return oprf.P521Sha512 + } + + t.Fatalf("unknown suite: %s", s) + return 0 } func (v vector) test(t *testing.T) { // Check mode, hash function, and cipher suite v.checkParams(t) - // Get mode, hash function, and cipher suite - mode := v.Mode - suite := v.SuiteID - - privKey, err := hex.DecodeString(v.SkSm) - if err != nil { - t.Fatalf("private key decoding errored with %q\nfor sksm %v\n", err, v.SkSm) - } - - var serverPublicKey []byte - if mode == voprf.VOPRF || mode == voprf.POPRF { - pksm, err := hex.DecodeString(v.PkSm) - if err != nil { - t.Fatalf("error decoding public key %v", err) - } - serverPublicKey = pksm - } - - expectedDST, err := hex.DecodeString(v.DST) - if err != nil { - t.Fatalf("hex decoding errored with %q", err) - } - for i, tv := range v.TestVectors { t.Run(fmt.Sprintf("Vector %d", i), func(t *testing.T) { - v.testVector(t, &tv, suite, mode, privKey, serverPublicKey, expectedDST) + tv.Ciphersuite = suiteToCiphersuite(t, v.SuiteID) + + test, err := tv.Decode() + if err != nil { + t.Fatal(fmt.Sprintf("batches : %v Failed %v\n", tv.Batch, err)) + } + + v.testVector(t, test) }) } } @@ -473,6 +587,6 @@ func TestVOPRFVectors(t *testing.T) { continue } - t.Run(string(tv.Mode)+" - "+string(tv.SuiteID), tv.test) + t.Run(string(tv.Mode)+" - "+tv.SuiteID, tv.test) } } diff --git a/tests/voprf_test.go b/tests/voprf_test.go index d2fa8ba..2d6a6e5 100644 --- a/tests/voprf_test.go +++ b/tests/voprf_test.go @@ -13,20 +13,22 @@ import ( "errors" "testing" - "github.com/bytemare/voprf" + oprf "github.com/bytemare/voprf" + "github.com/bytemare/voprf/voprf" ) var errExpectedEquality = errors.New("expected equality") -func makeClientAndServer(t *testing.T, mode voprf.Mode, ciphersuite voprf.Ciphersuite) (*voprf.Client, *voprf.Server) { - server, err := ciphersuite.Server(mode, nil) - if err != nil { +func makeVPClientAndServer(t *testing.T, ciphersuite oprf.Ciphersuite, info []byte) (*voprf.Client, *voprf.Server) { + sk := ciphersuite.Group().NewScalar().Random() + pk := ciphersuite.Group().Base().Multiply(sk) + + server := voprf.NewServer(ciphersuite, info...) + if err := server.SetKeyPair(sk, pk); err != nil { t.Fatal(err) } - spk := server.PublicKey() - - client, err := ciphersuite.Client(mode, spk) + client, err := voprf.NewClient(ciphersuite, pk, info...) if err != nil { t.Fatal(err) } @@ -34,53 +36,58 @@ func makeClientAndServer(t *testing.T, mode voprf.Mode, ciphersuite voprf.Cipher return client, server } -func runOPRF(t *testing.T, c *configuration, mode voprf.Mode, input, info []byte) *voprf.Evaluation { - client, server := makeClientAndServer(t, mode, c.ciphersuite) +func TestOPRF(t *testing.T) { + input := []byte("input") + + testAll(t, func(c *configuration) { + serverKey := c.group.NewScalar().Random() + client := c.ciphersuite.Client() + blinded := client.Blind(input) + evaluated := oprf.Evaluate(serverKey, blinded) + _ = client.Finalize(evaluated) + }) +} - blinded := client.Blind(input, info) +func doVPOPRF(t *testing.T, input, info []byte, c *configuration) { + serverKey := c.group.NewScalar().Random() + serverPubkey := c.group.Base().Multiply(serverKey) - evaluation, err := server.Evaluate(blinded, info) - if err != nil { + server := voprf.NewServer(c.ciphersuite, info...) + if err := server.SetKeyPair(serverKey, serverPubkey); err != nil { t.Fatal(err) } - if _, err = client.Finalize(evaluation, info); err != nil { + client, err := voprf.NewClient(c.ciphersuite, serverPubkey, info...) + if err != nil { t.Fatal(err) } - return evaluation -} - -func TestOPRF(t *testing.T) { - mode := voprf.OPRF - input := []byte("input") - - testAll(t, func(c *configuration) { - _ = runOPRF(t, c, mode, input, nil) - }) + blinded := client.Blind(input) + evaluation := server.Evaluate(blinded) + _, err = client.Finalize(evaluation) + if err != nil { + t.Fatal(err) + } } func TestVOPRF(t *testing.T) { - mode := voprf.VOPRF input := []byte("input") testAll(t, func(c *configuration) { - _ = runOPRF(t, c, mode, input, nil) + doVPOPRF(t, input, nil, c) }) } func TestPOPRF(t *testing.T) { - mode := voprf.POPRF info := []byte("info") input := []byte("input") testAll(t, func(c *configuration) { - _ = runOPRF(t, c, mode, input, info) + doVPOPRF(t, input, info, c) }) } func TestBatching(t *testing.T) { - mode := voprf.POPRF info := []byte("info") inputs := [][]byte{ []byte("input1"), @@ -89,29 +96,13 @@ func TestBatching(t *testing.T) { } testAll(t, func(c *configuration) { - client, server := makeClientAndServer(t, mode, c.ciphersuite) + client, server := makeVPClientAndServer(t, c.ciphersuite, info) + blinded := client.BlindBatch(inputs) + evaluation := server.EvaluateBatch(blinded) - _, blinded, err := client.BlindBatch(inputs, info) - if err != nil { + if _, err := client.FinalizeBatch(evaluation); err != nil { t.Fatal(err) } - - evaluation, err := server.EvaluateBatch(blinded, info) - if err != nil { - t.Fatal(err) - } - - if _, err = client.FinalizeBatch(evaluation, info); err != nil { - t.Fatal(err) - } - }) -} - -func TestAvailability(t *testing.T) { - testAll(t, func(c *configuration) { - if !c.ciphersuite.Available() { - t.Fatal("expected availability") - } }) } @@ -121,10 +112,7 @@ func TestCiphersuiteGroup(t *testing.T) { t.Fatal(errExpectedEquality) } - ciphersuite, err := voprf.FromGroup(c.group) - if err != nil { - t.Fatal(err) - } + ciphersuite := oprf.FromGroup(c.group) if ciphersuite != c.ciphersuite { t.Fatal(errExpectedEquality) @@ -132,43 +120,9 @@ func TestCiphersuiteGroup(t *testing.T) { }) } -func TestCiphersuiteHashes(t *testing.T) { - testAll(t, func(c *configuration) { - if c.hash != c.ciphersuite.Hash() { - t.Fatal(errExpectedEquality) - } - }) -} - -func TestServerKeys(t *testing.T) { - mode := voprf.OPRF - - testAll(t, func(c *configuration) { - server, err := c.ciphersuite.Server(mode, nil) - if err != nil { - t.Fatal(err) - } - - private := c.ciphersuite.Group().NewScalar() - if err = private.Decode(server.PrivateKey()); err != nil { - t.Fatal(err) - } - - public := c.ciphersuite.Group().NewElement() - if err = public.Decode(server.PublicKey()); err != nil { - t.Fatal(err) - } - - pk := c.ciphersuite.Group().Base().Multiply(private) - if pk.Equal(public) != 1 { - t.Fatal(errExpectedEquality) - } - }) -} - func TestDeriveKeyPair(t *testing.T) { info := []byte("some instance") - ciphersuite := voprf.Ristretto255Sha512 + ciphersuite := oprf.Ristretto255Sha512 random, _ := hex.DecodeString("c332260baab120459e7ad1d47ce5a43f980abe9c19ecc0550bbd0dde58a548bf") encodedReferenceSecretKeyR255, _ := hex.DecodeString( @@ -184,9 +138,9 @@ func TestDeriveKeyPair(t *testing.T) { refPk := ciphersuite.Group().NewElement() _ = refPk.Decode(encodedReferencePublicKeyR255) - keyPair := ciphersuite.DeriveKeyPair(voprf.OPRF, random, info) + sk, pk := oprf.DeriveKeyPair(ciphersuite, random, info) - if keyPair.SecretKey.Equal(refSk) != 1 || keyPair.PublicKey.Equal(refPk) != 1 { + if sk.Equal(refSk) != 1 || pk.Equal(refPk) != 1 { t.Fatal(errExpectedEquality) } } diff --git a/verifiable.go b/verifiable.go deleted file mode 100644 index 7c3152c..0000000 --- a/verifiable.go +++ /dev/null @@ -1,117 +0,0 @@ -// SPDX-License-Identifier: MIT -// -// Copyright (C) 2024 Daniel Bourdrez. All Rights Reserved. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree or at -// https://spdx.org/licenses/MIT.html - -package voprf - -import ( - group "github.com/bytemare/crypto" -) - -const ( - dstComposite = "Composite" - dstChallenge = "Challenge" - dstFinalize = "Finalize" - dstSeed = "Seed-" - dstInfo = "Info" -) - -func (o *oprf) ccScalar(encSeed []byte, index int, ci, di *group.Element) *group.Scalar { - input := concatenate(encSeed, i2osp2(index), - lengthPrefixEncode(ci.Encode()), - lengthPrefixEncode(di.Encode()), - []byte(dstComposite)) - - return o.HashToScalar(input) -} - -func (o *oprf) computeCompositesFast(k *group.Scalar, encSeed []byte, cs, ds []*group.Element) (m, z *group.Element) { - m = o.group.NewElement().Identity() - - for i, ci := range cs { - di := o.ccScalar(encSeed, i, ci, ds[i]) - m = ci.Copy().Multiply(di).Add(m) - } - - return m, m.Copy().Multiply(k) -} - -func (o *oprf) computeCompositesClient(encSeed []byte, cs, ds []*group.Element) (m, z *group.Element) { - m = o.group.NewElement().Identity() - z = o.group.NewElement().Identity() - - for i, ci := range cs { - di := o.ccScalar(encSeed, i, ci, ds[i]) - m = ci.Copy().Multiply(di).Add(m) - z = ds[i].Copy().Multiply(di).Add(z) - } - - return m, z -} - -func (o *oprf) computeComposites(k *group.Scalar, encGk []byte, cs, ds []*group.Element) (m, z *group.Element) { - // DST - encSeedDST := lengthPrefixEncode(dst(dstSeed, o.contextString)) - - // build seed - seed := o.hash.Hash(0, encGk, encSeedDST) - encSeed := lengthPrefixEncode(seed) - - // This means where calling from the server, and can optimize computation of Z, since Zi = sks * Mi - if k != nil { - return o.computeCompositesFast(k, encSeed, cs, ds) - } - - return o.computeCompositesClient(encSeed, cs, ds) -} - -func (o *oprf) challenge(encPks []byte, a0, a1, a2, a3 *group.Element) *group.Scalar { - encA0 := lengthPrefixEncode(a0.Encode()) - encA1 := lengthPrefixEncode(a1.Encode()) - encA2 := lengthPrefixEncode(a2.Encode()) - encA3 := lengthPrefixEncode(a3.Encode()) - encDST := []byte(dstChallenge) - input := concatenate(encPks, encA0, encA1, encA2, encA3, encDST) - - return o.HashToScalar(input) -} - -func (o *oprf) generateProof( - random, k *group.Scalar, - pk *group.Element, - cs, ds []*group.Element, -) (proofC, proofS *group.Scalar) { - encPk := lengthPrefixEncode(pk.Encode()) - a0, a1 := o.computeComposites(k, encPk, cs, ds) - - a2 := o.group.Base().Multiply(random) - a3 := a0.Copy().Multiply(random) - - proofC = o.challenge(encPk, a0, a1, a2, a3) - proofS = random.Subtract(proofC.Copy().Multiply(k)) - - return proofC, proofS -} - -func (o *oprf) verifyProof(ev *evaluation, pk *group.Element, cs, ds []*group.Element) error { - encGk := lengthPrefixEncode(pk.Encode()) - a0, a1 := o.computeComposites(nil, encGk, cs, ds) - - ap := pk.Copy().Multiply(ev.proofC) - a2 := o.group.Base().Multiply(ev.proofS).Add(ap) - - bm := a0.Copy().Multiply(ev.proofS) - bz := a1.Copy().Multiply(ev.proofC) - a3 := bm.Add(bz) - expectedC := o.challenge(encGk, a0, a1, a2, a3) - - if !ctEqual(expectedC.Encode(), ev.proofC.Encode()) { - return errProofFailed - } - - return nil -} diff --git a/voprf/client.go b/voprf/client.go new file mode 100644 index 0000000..03a69d5 --- /dev/null +++ b/voprf/client.go @@ -0,0 +1,173 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (C) 2024 Daniel Bourdrez. All Rights Reserved. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree or at +// https://spdx.org/licenses/MIT.html + +package voprf + +import ( + "errors" + + group "github.com/bytemare/crypto" + + "github.com/bytemare/voprf" + "github.com/bytemare/voprf/internal" +) + +var ( + errInvalidPublicKey = errors.New("server public key is either nil or the identity element") + errInputNotSet = errors.New("no prior input found") + errDifferentSize = errors.New("number of evaluations differs from number of previously blinded elements") + errInputNilEval = errors.New("provided evaluation is nil") + errInputNoEval = errors.New("provided evaluation does not contain evaluations") + errInputProofCNil = errors.New("proof c is nil") + errInputProofCZero = errors.New("proof c is zero") + errInputProofSNil = errors.New("proof s is nil") + errInputProofSZero = errors.New("proof s is zero") +) + +// Client is used for VOPRF or POPRF client executions. For OPRF or TOPRF, used oprf.Client. +type Client struct { + oprf *voprf.Client + verifiable *internal.Verifiable + serverPublicKey *group.Element + tweakedKey *group.Element + blindedInput []*group.Element +} + +// NewClient returns a client given the ciphersuite and the server's public key. poprfInfo must only be provided if +// the POPRF mode is requested. If poprfInfo is not provided or nil, the VOPRF mode is used. +func NewClient(cs voprf.Ciphersuite, serverPublicKey *group.Element, poprfInfo ...byte) (*Client, error) { + if serverPublicKey == nil || serverPublicKey.IsIdentity() { + return nil, errInvalidPublicKey + } + + mode := internal.VOPRF + + // If info is given, then a POPRF is requested by the caller. + if len(poprfInfo) != 0 { + mode = internal.POPRF + } + + c := internal.NewClient(mode, group.Group(cs)) + + client := &Client{ + oprf: &voprf.Client{ + Client: c, + }, + verifiable: internal.NewVerifiable(c.Core, poprfInfo), + serverPublicKey: serverPublicKey, + tweakedKey: nil, + blindedInput: nil, + } + + if mode == internal.POPRF { + client.tweakedKey = client.verifiable.TweakPublicKey(serverPublicKey) + } + + return client, nil +} + +// SetBlind sets one or multiple blinds in the client's blind register. This is optional, and useful if you want to +// force usage of specific blinding scalar. If no blinding scalars are set, new, random blinds will be used. +func (c *Client) SetBlind(blind ...*group.Scalar) { + c.oprf.SetBlind(blind...) +} + +// Blind blinds the input using the first blinding scalar in the Client's register. If no blinding scalars were +// previously set, new, random blinds will be used. +func (c *Client) Blind(input []byte) *group.Element { + c.blindedInput = make([]*group.Element, 1) + c.blindedInput[0] = c.oprf.Blind(input) + + return c.blindedInput[0] +} + +// BlindBatch blinds the given set, using either previously set blinds in the same order (if they have been set) or +// newly generated random blinds. Note that if not enough blinds were set, new, random blinds will be used as necessary. +func (c *Client) BlindBatch(inputs [][]byte) []*group.Element { + c.blindedInput = c.oprf.BlindBatch(inputs) + return c.blindedInput +} + +func (c *Client) verifyProof(evaluation *Evaluation) error { + var pk *group.Element + var cs, ds []*group.Element + + if len(c.blindedInput) == 0 { + return errInputNotSet + } + + if c.oprf.Mode == internal.VOPRF { + cs, ds = c.blindedInput, evaluation.Evaluations + pk = c.serverPublicKey + } else { // POPRF + cs, ds = evaluation.Evaluations, c.blindedInput + pk = c.tweakedKey + } + + return c.verifiable.VerifyProof(evaluation.Proof[0], evaluation.Proof[1], pk, cs, ds) +} + +func (c *Client) checkEvaluation(evaluation *Evaluation) error { + if evaluation == nil { + return errInputNilEval + } + + if len(evaluation.Evaluations) == 0 { + return errInputNoEval + } + + if evaluation.Proof[0] == nil { + return errInputProofCNil + } + + if evaluation.Proof[0].IsZero() { + return errInputProofCZero + } + + if evaluation.Proof[1] == nil { + return errInputProofSNil + } + + if evaluation.Proof[1].IsZero() { + return errInputProofSZero + } + + if len(evaluation.Evaluations) != len(c.blindedInput) { + return errDifferentSize + } + + return nil +} + +// Finalize verifies the Server provided proofs, and, if they are valid, unblinds the evaluated element and returns +// the protocol output. +func (c *Client) Finalize(evaluation *Evaluation) ([]byte, error) { + if err := c.checkEvaluation(evaluation); err != nil { + return nil, err + } + + if err := c.verifyProof(evaluation); err != nil { + return nil, err + } + + return c.oprf.Client.Finalize(0, evaluation.Evaluations[0], c.verifiable.POPRFInfo...), nil +} + +// FinalizeBatch verifies the Server provided proofs, and, if they are valid, unblinds the evaluated elements and +// returns the protocol output. +func (c *Client) FinalizeBatch(evaluation *Evaluation) ([][]byte, error) { + if err := c.checkEvaluation(evaluation); err != nil { + return nil, err + } + + if err := c.verifyProof(evaluation); err != nil { + return nil, err + } + + return c.oprf.Client.FinalizeBatch(evaluation.Evaluations, c.verifiable.POPRFInfo...) +} diff --git a/voprf/doc.go b/voprf/doc.go new file mode 100644 index 0000000..68c814b --- /dev/null +++ b/voprf/doc.go @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (C) 2024 Daniel Bourdrez. All Rights Reserved. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree or at +// https://spdx.org/licenses/MIT.html + +// Package voprf implements RFC9497 and provides abstracted access to Verifiable Oblivious Pseudorandom Functions +// (VOPRF) and Partially Oblivious Pseudorandom Functions (POPRF) using Elliptic Curve Prime Order Groups (EC-OPRF). +// For OPRF and TOPRF use the github.com/bytemare/oprf package. +package voprf diff --git a/voprf/messages.go b/voprf/messages.go new file mode 100644 index 0000000..42156eb --- /dev/null +++ b/voprf/messages.go @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (C) 2024 Daniel Bourdrez. All Rights Reserved. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree or at +// https://spdx.org/licenses/MIT.html + +// Package voprf implements RFC9497 and provides abstracted access to Oblivious Pseudorandom Functions (OPRF) and +// Threshold Oblivious Pseudorandom Functions (TOPRF) using Elliptic Curve Prime Order Groups (EC-OPRF). +// For VOPRF and POPRF use the github.com/bytemare/oprf/voprf package. +package voprf + +import ( + "encoding/json" + "errors" + "fmt" + + group "github.com/bytemare/crypto" + + "github.com/bytemare/voprf" + "github.com/bytemare/voprf/internal" +) + +var ( + errUnmarshalEvaluationShort = errors.New("decoding error: insufficient data length") + errUnmarshalEvaluationEvals = errors.New("decoding error: wrong encoding length") +) + +// Evaluation is the VOPRF and POPRF servers' output, containing the verifiable proof and evaluated elements. +// To decode a byte string back to an Evaluation, the SetCiphersuite must be used with the relevant ciphersuite. +type Evaluation struct { + // Proof is the NIZK proof over the Evaluations elements. + Proof [2]*group.Scalar `json:"p"` + + // Evaluations is the set of evaluated elements. + Evaluations []*group.Element `json:"e"` + group group.Group +} + +func (e *Evaluation) encodeProof() [2][]byte { + return [2][]byte{ + e.Proof[0].Encode(), + e.Proof[1].Encode(), + } +} + +func (e *Evaluation) encodeEvaluations() []byte { + nEval := len(e.Evaluations) + lenEval := len(e.Evaluations[0].Encode()) + + output := make([]byte, 0, 2+nEval*lenEval) + output = append(output, internal.I2osp2(nEval)...) + + for _, eval := range e.Evaluations { + output = append(output, eval.Encode()...) + } + + return output +} + +// Serialize returns the compact byte encoding of the Evaluation. +func (e *Evaluation) Serialize() []byte { + proof := e.encodeProof() + evaluations := e.encodeEvaluations() + + output := make([]byte, 0, len(proof)+len(evaluations)) + output = append(output, proof[0]...) + output = append(output, proof[1]...) + output = append(output, evaluations...) + + return output +} + +// SetCiphersuite needs to be set by a client on a new Evaluation before decoding it from its compact serialization. +func (e *Evaluation) SetCiphersuite(c voprf.Ciphersuite) { + e.group = c.Group() +} + +func decodeProof(g group.Group, data []byte) ([]*group.Scalar, error) { + sLen := g.ScalarLength() + + pc := g.NewScalar() + if err := pc.Decode(data[:sLen]); err != nil { + return nil, fmt.Errorf("invalid c proof encoding: %w", err) + } + + ps := g.NewScalar() + if err := ps.Decode(data[sLen : 2*sLen]); err != nil { + return nil, fmt.Errorf("invalid s proof encoding: %w", err) + } + + return []*group.Scalar{pc, ps}, nil +} + +func decodeEvaluations(g group.Group, nbEvals int, data []byte) ([]*group.Element, error) { + pLen := g.ElementLength() + i := 0 + evaluations := make([]*group.Element, nbEvals) + + for offset := 0; offset < len(evaluations); offset += pLen { + decoded := g.NewElement() + if err := decoded.Decode(data[offset : offset+pLen]); err != nil { + return nil, fmt.Errorf("invalid evaluation encoding - element %d: %w", i, err) + } + + evaluations[i] = decoded + i++ + } + + return evaluations, nil +} + +// Deserialize decodes a compact serialization of an Evaluation into e. +func (e *Evaluation) Deserialize(data []byte) error { + sLen := e.group.ScalarLength() + pLen := e.group.ElementLength() + + expectedProofLen := 2 * sLen + minimalEvaluationLength := 2 + pLen + + if len(data) < expectedProofLen+minimalEvaluationLength { + return errUnmarshalEvaluationShort + } + + evaluationOffset := expectedProofLen + nbEvals := int(uint16(data[evaluationOffset+1]) | uint16(data[evaluationOffset])<<8) + + evaluations := data[evaluationOffset+2:] + if len(evaluations) != nbEvals*pLen { + return errUnmarshalEvaluationEvals + } + + proof, err := decodeProof(e.group, data[:expectedProofLen]) + if err != nil { + return err + } + + evals, err := decodeEvaluations(e.group, nbEvals, evaluations) + if err != nil { + return err + } + + e.Proof[0] = proof[0] + e.Proof[1] = proof[1] + e.Evaluations = evals + + return nil +} + +// MarshalBinary encodes the Evaluation into its binary form. +func (e *Evaluation) MarshalBinary() ([]byte, error) { + return e.Serialize(), nil +} + +// UnmarshalBinary decodes the binary form of an Evaluation into e. +func (e *Evaluation) UnmarshalBinary(data []byte) error { + return e.Deserialize(data) +} + +// MarshalJSON encodes the Evaluation into JSON. +func (e *Evaluation) MarshalJSON() ([]byte, error) { + enc := struct { + Proof [2][]byte `json:"p"` + Eval []byte `json:"e"` + }{ + Proof: e.encodeProof(), + Eval: e.encodeEvaluations(), + } + + out, err := json.Marshal(enc) + if err != nil { + return nil, fmt.Errorf("encoding evaluation: %w", err) + } + + return out, nil +} + +// UnmarshalJSON decodes a JSON encoded Evaluation into e. +func (e *Evaluation) UnmarshalJSON(data []byte) error { + return e.Deserialize(data) +} diff --git a/voprf/server.go b/voprf/server.go new file mode 100644 index 0000000..aa6554c --- /dev/null +++ b/voprf/server.go @@ -0,0 +1,171 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (C) 2024 Daniel Bourdrez. All Rights Reserved. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree or at +// https://spdx.org/licenses/MIT.html + +package voprf + +import ( + "errors" + + group "github.com/bytemare/crypto" + + "github.com/bytemare/voprf" + "github.com/bytemare/voprf/internal" +) + +// Server is used for VOPRF or POPRF server executions. For OPRF or TOPRF, used the oprf package (no need for a server +// instance). +type Server struct { + // OPRF + *internal.Verifiable + + // VOPRF + privateKey *group.Scalar + publicKey *group.Element + + // POPRF + scalar *group.Scalar + t *group.Scalar + tweakedKey *group.Element +} + +// NewServer returns a server instance given a ciphersuite. poprfInfo must only be provided if +// the POPRF mode is requested. If poprfInfo is not provided or nil, the VOPRF mode is used. +func NewServer(cs voprf.Ciphersuite, poprfInfo ...byte) *Server { + mode := internal.VOPRF + if len(poprfInfo) != 0 { + mode = internal.POPRF + } + + s := &Server{ + Verifiable: internal.NewVerifiable(internal.LoadConfiguration(group.Group(cs), mode), poprfInfo), + privateKey: nil, + publicKey: nil, + scalar: nil, + t: nil, + tweakedKey: nil, + } + + return s +} + +var ( + errInvalidPrivateKey = errors.New("private key is nil or zero") + errInvalidKeyPair = errors.New("input public key doesn't belong to the private key") +) + +func checkKeys(g group.Group, privateKey *group.Scalar, publicKey *group.Element) error { + if publicKey == nil || publicKey.IsIdentity() { + return errInvalidPublicKey + } + + if privateKey == nil || privateKey.IsZero() { + return errInvalidPrivateKey + } + + if g.Base().Multiply(privateKey).Equal(publicKey) != 1 { + return errInvalidKeyPair + } + + return nil +} + +func (s *Server) setKeyPair(privateKey *group.Scalar, publicKey *group.Element) { + s.privateKey = privateKey + s.publicKey = publicKey + + if s.Core.Mode == internal.POPRF { + s.scalar, s.t = s.Verifiable.TweakPrivateKey(privateKey) + s.tweakedKey = s.Core.Group.Base().Multiply(s.t) + } else { + s.scalar = s.privateKey + } +} + +// SetKeyPair sets the server's private and public key pair. This returns an error if either key is nil, the public key +// is the identity element, or if it doesn't match as a public key to the provided private key. +func (s *Server) SetKeyPair(privateKey *group.Scalar, publicKey *group.Element) error { + if err := checkKeys(s.Core.Group, privateKey, publicKey); err != nil { + return err + } + + s.setKeyPair(privateKey, publicKey) + + return nil +} + +// DeriveKeyPair derives and set the server's private and public key pair given a secret seed and instance specific +// info. +func (s *Server) DeriveKeyPair(seed, info []byte) { + sk, pk := s.Core.DeriveKeyPair(seed, info) + s.setKeyPair(sk, pk) +} + +// GenerateKeys generates and sets a new, random private and public key pair. +func (s *Server) GenerateKeys() { + sk := s.Core.Group.NewScalar().Random() + pk := s.Core.Group.Base().Multiply(sk) + s.setKeyPair(sk, pk) +} + +// KeyPair returns the server's private and public key pair. +func (s *Server) KeyPair() (*group.Scalar, *group.Element) { + return s.privateKey, s.publicKey +} + +func (s *Server) evaluate( + blinded []*group.Element, + random []*group.Scalar, +) *Evaluation { + // Set the random element for the proof + r := s.Group.NewScalar() + if len(random) != 0 && random[0] != nil { + r = random[0] + } else { + r.Random() + } + + // Evaluate + evaluated := voprf.EvaluateBatch(s.scalar, blinded) + + var proofC, proofS *group.Scalar + + if s.Core.Mode == internal.VOPRF { + proofC, proofS = s.Verifiable.GenerateProof(r, s.privateKey, s.publicKey, blinded, evaluated) + } else { // POPRF + proofC, proofS = s.Verifiable.GenerateProof(r, s.t, s.tweakedKey, evaluated, blinded) + } + + return &Evaluation{ + group: s.Group, + Proof: [2]*group.Scalar{ + proofC, proofS, + }, + Evaluations: evaluated, + } +} + +// Evaluate takes the Client provided blinded element and evaluates it, returning the evaluated element and the +// NIZK proof. The random argument is optional, and enables to force the use of that scalar for the random input to the +// NIZK proof. +func (s *Server) Evaluate( + blinded *group.Element, + random ...*group.Scalar, +) *Evaluation { + sBlinded := []*group.Element{blinded} + return s.evaluate(sBlinded, random) +} + +// EvaluateBatch takes the Client provided blinded elements and evaluates them, returning the evaluated elements and the +// unique NIZK proof for the whole set. The random argument is optional, and enables to force the use of that scalar for +// the random input to the NIZK proof. +func (s *Server) EvaluateBatch( + blinded []*group.Element, + random ...*group.Scalar, +) *Evaluation { + return s.evaluate(blinded, random) +}