Skip to content
This repository has been archived by the owner on Oct 3, 2024. It is now read-only.

Commit

Permalink
Equal() now returns a boolean (#67)
Browse files Browse the repository at this point in the history
* Equal() now returns a boolean

---------

Signed-off-by: bytemare <[email protected]>
  • Loading branch information
bytemare authored Oct 1, 2024
1 parent e89a42e commit 5fb54db
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 74 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
strategy:
fail-fast: false
matrix:
go: [ '1.23.2', '1.22.8', '1.21.13' ]
go: [ '1.23', '1.22', '1.21' ]
uses: bytemare/workflows/.github/workflows/test-go.yml@f572ea606a74fe011e68a23c19f8d4f5daf58488
with:
command: cd .github && make test
Expand Down
8 changes: 4 additions & 4 deletions element.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,13 @@ func (e *Element) Multiply(scalar *Scalar) *Element {
return e
}

// Equal returns 1 if the elements are equivalent, and 0 otherwise.
func (e *Element) Equal(element *Element) int {
// Equal returns true if the elements are equivalent, and false otherwise.
func (e *Element) Equal(element *Element) bool {
if element == nil {
return 0
return false
}

return e.Element.Equal(element.Element)
return e.Element.Equal(element.Element) == 1
}

// IsIdentity returns whether the Element is the point at infinity of the Group's underlying curve.
Expand Down
8 changes: 4 additions & 4 deletions scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ func (s *Scalar) Invert() *Scalar {
return s
}

// Equal returns 1 if the scalars are equal, and 0 otherwise.
func (s *Scalar) Equal(scalar *Scalar) int {
// Equal returns true if the elements are equivalent, and false otherwise.
func (s *Scalar) Equal(scalar *Scalar) bool {
if scalar == nil {
return 0
return false
}

return s.Scalar.Equal(scalar.Scalar)
return s.Scalar.Equal(scalar.Scalar) == 1
}

// LessOrEqual returns 1 if s <= scalar, and 0 otherwise.
Expand Down
64 changes: 32 additions & 32 deletions tests/element_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,23 @@ func testElementCopySet(t *testing.T, element, other *crypto.Element) {
}

// Verify whether they are equivalent
if element.Equal(other) != 1 {
if !element.Equal(other) {
t.Fatalf("Expected equality")
}

// Verify than operations on one don't affect the other
element.Add(element)
if element.Equal(other) == 1 {
if element.Equal(other) {
t.Fatalf(errUnExpectedEquality)
}

other.Double().Double()
if element.Equal(other) == 1 {
if element.Equal(other) {
t.Fatalf(errUnExpectedEquality)
}

// Verify setting to nil sets to identity
if element.Set(nil).Equal(other.Identity()) != 1 {
if !element.Set(nil).Equal(other.Identity()) {
t.Error(errExpectedEquality)
}
}
Expand Down Expand Up @@ -87,7 +87,7 @@ func TestElement_WrongInput(t *testing.T) {
}
}

equal := func(f func(*crypto.Element) int, arg *crypto.Element) func() {
equal := func(f func(*crypto.Element) bool, arg *crypto.Element) func() {
return func() {
f(arg)
}
Expand Down Expand Up @@ -254,19 +254,19 @@ func TestElement_Vectors_Add(t *testing.T) {

for _, mult := range group.multBase {
e := decodeElement(t, group.group, mult)
if e.Equal(acc) != 1 {
if !e.Equal(acc) {
t.Fatal("expected equality")
}

acc.Add(base)
}

base.Add(group.group.NewElement())
if base.Equal(group.group.Base()) != 1 {
if !base.Equal(group.group.Base()) {
t.Fatal(errExpectedEquality)
}

if group.group.NewElement().Add(base).Equal(base) != 1 {
if !group.group.NewElement().Add(base).Equal(base) {
t.Fatal(errExpectedEquality)
}
})
Expand All @@ -287,7 +287,7 @@ func TestElement_Vectors_Double(t *testing.T) {
e.Double()

v := decodeElement(t, group.group, group.multBase[multiple-1])
if v.Equal(e) != 1 {
if !v.Equal(e) {
t.Fatalf("expected equality for %d", multiple)
}
}
Expand All @@ -302,7 +302,7 @@ func TestElement_Vectors_Mult(t *testing.T) {

for i, mult := range group.multBase {
e := decodeElement(t, group.group, mult)
if e.Equal(base) != 1 {
if !e.Equal(base) {
t.Fatalf("expected equality for %d", i)
}

Expand All @@ -328,17 +328,17 @@ func elementTestEqual(t *testing.T, g crypto.Group) {
base := g.Base()
base2 := g.Base()

if base.Equal(nil) != 0 {
if base.Equal(nil) {
t.Fatal(errUnExpectedEquality)
}

if base.Equal(base2) != 1 {
if !base.Equal(base2) {
t.Fatal(errExpectedEquality)
}

random := g.NewElement().Multiply(g.NewScalar().Random())
cpy := random.Copy()
if random.Equal(cpy) != 1 {
if !random.Equal(cpy) {
t.Fatal()
}
}
Expand All @@ -347,37 +347,37 @@ func elementTestAdd(t *testing.T, g crypto.Group) {
// Verify whether add yields the same element when given nil
base := g.Base()
cpy := base.Copy()
if cpy.Add(nil).Equal(base) != 1 {
if !cpy.Add(nil).Equal(base) {
t.Fatal(errExpectedEquality)
}

// Verify whether add yields the same element when given identity
base = g.Base()
cpy = base.Copy()
cpy.Add(g.NewElement())
if cpy.Equal(base) != 1 {
if !cpy.Equal(base) {
t.Fatal(errExpectedEquality)
}

// Verify whether add yields the same when adding to identity
base = g.Base()
identity := g.NewElement()
if identity.Add(base).Equal(base) != 1 {
if !identity.Add(base).Equal(base) {
t.Fatal(errExpectedEquality)
}

// Verify whether add yields the identity given the negative
base = g.Base()
negative := g.Base().Negate()
identity = g.NewElement()
if base.Add(negative).Equal(identity) != 1 {
if !base.Add(negative).Equal(identity) {
t.Fatal(errExpectedEquality)
}

// Verify whether add yields the double when adding to itself
base = g.Base()
double := g.Base().Double()
if base.Add(base).Equal(double) != 1 {
if !base.Add(base).Equal(double) {
t.Fatal(errExpectedEquality)
}

Expand All @@ -389,7 +389,7 @@ func elementTestAdd(t *testing.T, g crypto.Group) {
mult := g.Base().Multiply(three)
e := g.Base().Add(g.Base()).Add(g.Base())

if e.Equal(mult) != 1 {
if !e.Equal(mult) {
t.Fatal(errExpectedEquality)
}
}
Expand All @@ -399,7 +399,7 @@ func elementTestNegate(t *testing.T, g crypto.Group) {
id := g.NewElement().Identity()
negId := g.NewElement().Identity().Negate()

if id.Equal(negId) != 1 {
if !id.Equal(negId) {
t.Fatal("expected equality when negating identity element")
}

Expand All @@ -416,7 +416,7 @@ func elementTestNegate(t *testing.T, g crypto.Group) {
b = g.NewElement().Base()
negB = g.NewElement().Base().Negate().Negate()

if b.Equal(negB) != 1 {
if !b.Equal(negB) {
t.Fatal("expected equality -(-b) = b")
}
}
Expand All @@ -425,13 +425,13 @@ func elementTestDouble(t *testing.T, g crypto.Group) {
// Verify whether double works like adding
base := g.Base()
double := g.Base().Add(g.Base())
if double.Equal(base.Double()) != 1 {
if !double.Equal(base.Double()) {
t.Fatal(errExpectedEquality)
}

two := g.NewScalar().One().Add(g.NewScalar().One())
mult := g.Base().Multiply(two)
if mult.Equal(double) != 1 {
if !mult.Equal(double) {
t.Fatal(errExpectedEquality)
}
}
Expand All @@ -440,13 +440,13 @@ func elementTestSubstract(t *testing.T, g crypto.Group) {
base := g.Base()

// Verify whether subtracting yields the same element when given nil.
if base.Subtract(nil).Equal(base) != 1 {
if !base.Subtract(nil).Equal(base) {
t.Fatal(errExpectedEquality)
}

// Verify whether subtracting and then adding yields the same element.
base2 := base.Add(base).Subtract(base)
if base.Equal(base2) != 1 {
if !base.Equal(base2) {
t.Fatal(errExpectedEquality)
}
}
Expand All @@ -457,7 +457,7 @@ func elementTestMultiply(t *testing.T, g crypto.Group) {
// base = base * 1
base := g.Base()
mult := g.Base().Multiply(scalar.One())
if base.Equal(mult) != 1 {
if !base.Equal(mult) {
t.Fatal(errExpectedEquality)
}

Expand All @@ -473,7 +473,7 @@ func elementTestMultiply(t *testing.T, g crypto.Group) {
two := g.NewScalar().One().Add(g.NewScalar().One())
mult = g.Base().Multiply(two)

if mult.Equal(twoG) != 1 {
if !mult.Equal(twoG) {
t.Fatal(errExpectedEquality)
}

Expand All @@ -495,30 +495,30 @@ func elementTestIdentity(t *testing.T, g crypto.Group) {
}

base := g.Base()
if id.Equal(base.Subtract(base)) != 1 {
if !id.Equal(base.Subtract(base)) {
log.Printf("id : %v", id.Encode())
log.Printf("ba : %v", base.Encode())
t.Fatal(errExpectedIdentity)
}

sub1 := g.Base().Double().Negate().Add(g.Base().Double())
sub2 := g.Base().Subtract(g.Base())
if sub1.Equal(sub2) != 1 {
if !sub1.Equal(sub2) {
t.Fatal(errExpectedEquality)
}

if id.Equal(base.Multiply(nil)) != 1 {
if !id.Equal(base.Multiply(nil)) {
t.Fatal(errExpectedIdentity)
}

if id.Equal(base.Multiply(g.NewScalar().Zero())) != 1 {
if !id.Equal(base.Multiply(g.NewScalar().Zero())) {
t.Fatal(errExpectedIdentity)
}

base = g.Base()
neg := base.Copy().Negate()
base.Add(neg)
if id.Equal(base) != 1 {
if !id.Equal(base) {
t.Fatal(errExpectedIdentity)
}
}
8 changes: 4 additions & 4 deletions tests/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func testScalarEncodings(g crypto.Group, f makeEncodeTest) error {
return err
}

if source.Equal(receiver) != 1 {
if !source.Equal(receiver) {
return errors.New(errExpectedEquality)
}

Expand All @@ -152,7 +152,7 @@ func testElementEncodings(g crypto.Group, f makeEncodeTest) error {
return err
}

if source.Equal(receiver) != 1 {
if !source.Equal(receiver) {
return errors.New(errExpectedEquality)
}

Expand Down Expand Up @@ -248,7 +248,7 @@ func TestEncoding_Hex_Fails(t *testing.T) {
t.Fatalf("unexpected error on valid encoding: %s", err)
}

if s.Equal(scalar) != 1 {
if !s.Equal(scalar) {
t.Fatal(errExpectedEquality)
}

Expand All @@ -258,7 +258,7 @@ func TestEncoding_Hex_Fails(t *testing.T) {
t.Fatalf("unexpected error on valid encoding: %s", err)
}

if e.Equal(element) != 1 {
if !e.Equal(element) {
t.Fatal(errExpectedEquality)
}
})
Expand Down
4 changes: 2 additions & 2 deletions tests/groups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func TestHashToScalar(t *testing.T) {
sv := decodeScalar(t, group.group, group.hashToCurve.hashToScalar)

s := group.group.HashToScalar(group.hashToCurve.input, group.hashToCurve.dst)
if s.Equal(sv) != 1 {
if !s.Equal(sv) {
t.Error(errExpectedEquality)
}
})
Expand Down Expand Up @@ -185,7 +185,7 @@ func TestHashToGroup(t *testing.T) {
ev := decodeElement(t, group.group, group.hashToCurve.hashToGroup)

e := group.group.HashToGroup(group.hashToCurve.input, group.hashToCurve.dst)
if e.Equal(ev) != 1 {
if !e.Equal(ev) {
t.Error(errExpectedEquality)
}
})
Expand Down
Loading

0 comments on commit 5fb54db

Please sign in to comment.