Skip to content
This repository has been archived by the owner on Oct 3, 2024. It is now read-only.

Commit

Permalink
Add Scalar.UInt64() to return a uint64 from the scalar (#58)
Browse files Browse the repository at this point in the history
* Add Scalar.UInt64() to return a uint64 from the scalar

Signed-off-by: bytemare <[email protected]>

---------

Signed-off-by: bytemare <[email protected]>
  • Loading branch information
bytemare authored Jun 13, 2024
1 parent cd6020a commit 3bec1a4
Show file tree
Hide file tree
Showing 9 changed files with 142 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type Scalar interface {
IsZero() bool
Set(Scalar) Scalar
SetUInt64(uint64) Scalar
UInt64() (uint64, error)
Copy() Scalar
Encode() []byte
Decode(in []byte) error
Expand Down
17 changes: 17 additions & 0 deletions internal/edwards25519/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,23 @@ func (s *Scalar) SetUInt64(i uint64) internal.Scalar {
return s
}

// UInt64 returns the uint64 representation of the scalar,
// or an error if its value is higher than the authorized limit for uint64.
func (s *Scalar) UInt64() (uint64, error) {
b := s.scalar.Bytes()
overflows := byte(0)

for _, bx := range b[8:] {
overflows |= bx
}

if overflows != 0 {
return 0, internal.ErrUInt64TooBig
}

return binary.LittleEndian.Uint64(b[:8]), nil
}

func (s *Scalar) copy() *Scalar {
return &Scalar{*ed.NewScalar().Set(&s.scalar)}
}
Expand Down
3 changes: 3 additions & 0 deletions internal/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ var (

// ErrParamScalarInvalidEncoding indicates an invalid scalar encoding has been provided, or that it's too big.
ErrParamScalarInvalidEncoding = errors.New("invalid scalar encoding")

// ErrUInt64TooBig indicates that the scalar is higher than the allowed values for uint64.
ErrUInt64TooBig = errors.New("scalar is too big to be uint64")
)

// An Encoder can encode itself to machine or human-readable forms.
Expand Down
19 changes: 19 additions & 0 deletions internal/nist/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package nist

import (
"crypto/subtle"
"encoding/binary"
"encoding/hex"
"fmt"
"math/big"
Expand Down Expand Up @@ -188,6 +189,24 @@ func (s *Scalar) SetUInt64(i uint64) internal.Scalar {
return s
}

// UInt64 returns the uint64 representation of the scalar,
// or an error if its value is higher than the authorized limit for uint64.
func (s *Scalar) UInt64() (uint64, error) {
b := s.Encode()
overflows := byte(0)
scalarLength := (s.field.BitLen() + 7) / 8

for _, bx := range b[:scalarLength-8] {
overflows |= bx
}

if overflows != 0 {
return 0, internal.ErrUInt64TooBig
}

return binary.BigEndian.Uint64(b[scalarLength-8:]), nil
}

// Copy returns a copy of the Scalar.
func (s *Scalar) Copy() internal.Scalar {
cpy := newScalar(s.field)
Expand Down
17 changes: 17 additions & 0 deletions internal/ristretto/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,23 @@ func (s *Scalar) SetUInt64(i uint64) internal.Scalar {
return s
}

// UInt64 returns the uint64 representation of the scalar,
// or an error if its value is higher than the authorized limit for uint64.
func (s *Scalar) UInt64() (uint64, error) {
b := s.scalar.Encode(nil)
overflows := byte(0)

for _, bx := range b[8:] {
overflows |= bx
}

if overflows != 0 {
return 0, internal.ErrUInt64TooBig
}

return binary.LittleEndian.Uint64(b[:8]), nil
}

func (s *Scalar) copy() *Scalar {
return &Scalar{*ristretto255.NewScalar().Add(ristretto255.NewScalar(), &s.scalar)}
}
Expand Down
4 changes: 4 additions & 0 deletions internal/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ type Scalar interface {
// SetUInt64 sets s to i modulo the field order, and returns an error if one occurs.
SetUInt64(i uint64) Scalar

// UInt64 returns the uint64 representation of the scalar,
// or an error if its value is higher than the authorized limit for uint64.
UInt64() (uint64, error)

// Copy returns a copy of the receiver.
Copy() Scalar

Expand Down
18 changes: 18 additions & 0 deletions internal/secp256k1/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package secp256k1

import (
"encoding/binary"
"fmt"

"github.com/bytemare/secp256k1"
Expand Down Expand Up @@ -151,6 +152,23 @@ func (s *Scalar) SetUInt64(i uint64) internal.Scalar {
return s
}

// UInt64 returns the uint64 representation of the scalar,
// or an error if its value is higher than the authorized limit for uint64.
func (s *Scalar) UInt64() (uint64, error) {
b := s.scalar.Encode()
overflows := byte(0)

for _, bx := range b[:scalarLength-8] {
overflows |= bx
}

if overflows != 0 {
return 0, internal.ErrUInt64TooBig
}

return binary.BigEndian.Uint64(b[scalarLength-8:]), nil
}

// Copy returns a copy of the receiver.
func (s *Scalar) Copy() internal.Scalar {
return &Scalar{scalar: s.scalar.Copy()}
Expand Down
11 changes: 11 additions & 0 deletions scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,17 @@ func (s *Scalar) SetUInt64(i uint64) *Scalar {
return s
}

// UInt64 returns the uint64 representation of the scalar,
// or an error if its value is higher than the authorized limit for uint64.
func (s *Scalar) UInt64() (uint64, error) {
i, err := s.Scalar.UInt64()
if err != nil {
return 0, fmt.Errorf("%w", err)
}

return i, nil
}

// Copy returns a copy of the receiver.
func (s *Scalar) Copy() *Scalar {
return &Scalar{Scalar: s.Scalar.Copy()}
Expand Down
52 changes: 52 additions & 0 deletions tests/scalar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,58 @@ func TestScalarSet(t *testing.T) {
})
}

func parseScalar(s *crypto.Scalar) ([]byte, bool) {
b := s.Encode()
b3 := b[8:]
b4 := byte(0)
for _, bx := range b3 {
b4 |= bx
}
return b[:8], b4 == 0
}

func testScalarUInt64(t *testing.T, s *crypto.Scalar, expectedValue uint64, expectedError error) {
i, err := s.UInt64()

if err == nil {
if expectedError != nil {
t.Fatalf("expected error %q", expectedError)
}
} else {
if expectedError == nil {
t.Fatalf("unexpected error %q", err)
} else if err.Error() != expectedError.Error() {
t.Fatalf("expected error %q, got %q", expectedError, err)
}
}

if expectedError == nil && i != expectedValue {
t.Fatalf("expected %d, got %d", expectedValue, i)
}
}

func TestScalar_UInt64(t *testing.T) {
expectedError := errors.New("scalar is too big to be uint64")
testAll(t, func(group *testGroup) {
// 0
testScalarUInt64(t, group.group.NewScalar(), 0, nil)

// 1
testScalarUInt64(t, group.group.NewScalar().One(), 1, nil)

// Max Uint64
testScalarUInt64(t, group.group.NewScalar().SetUInt64(math.MaxUint64), math.MaxUint64, nil)

// Max Uint64+1 fails
s := group.group.NewScalar().SetUInt64(math.MaxUint64).Add(group.group.NewScalar().One())
testScalarUInt64(t, s, 0, expectedError)

// Order - 1 fails
s = group.group.NewScalar().Subtract(group.group.NewScalar().One())
testScalarUInt64(t, s, 0, expectedError)
})
}

func TestScalar_SetUInt64(t *testing.T) {
testAll(t, func(group *testGroup) {
s := group.group.NewScalar().SetUInt64(0)
Expand Down

0 comments on commit 3bec1a4

Please sign in to comment.