Skip to content
This repository has been archived by the owner on Oct 3, 2024. It is now read-only.

Equal() now returns a boolean #67

Merged
merged 2 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
strategy:
fail-fast: false
matrix:
go: [ '1.23.2', '1.22.8', '1.21.13' ]
go: [ '1.23', '1.22', '1.21' ]
uses: bytemare/workflows/.github/workflows/test-go.yml@f572ea606a74fe011e68a23c19f8d4f5daf58488
with:
command: cd .github && make test
Expand Down
8 changes: 4 additions & 4 deletions element.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,13 @@ func (e *Element) Multiply(scalar *Scalar) *Element {
return e
}

// Equal returns 1 if the elements are equivalent, and 0 otherwise.
func (e *Element) Equal(element *Element) int {
// Equal returns true if the elements are equivalent, and false otherwise.
func (e *Element) Equal(element *Element) bool {
if element == nil {
return 0
return false
}

return e.Element.Equal(element.Element)
return e.Element.Equal(element.Element) == 1
}

// IsIdentity returns whether the Element is the point at infinity of the Group's underlying curve.
Expand Down
8 changes: 4 additions & 4 deletions scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ func (s *Scalar) Invert() *Scalar {
return s
}

// Equal returns 1 if the scalars are equal, and 0 otherwise.
func (s *Scalar) Equal(scalar *Scalar) int {
// Equal returns true if the elements are equivalent, and false otherwise.
func (s *Scalar) Equal(scalar *Scalar) bool {
if scalar == nil {
return 0
return false
}

return s.Scalar.Equal(scalar.Scalar)
return s.Scalar.Equal(scalar.Scalar) == 1
}

// LessOrEqual returns 1 if s <= scalar, and 0 otherwise.
Expand Down
64 changes: 32 additions & 32 deletions tests/element_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,23 @@ func testElementCopySet(t *testing.T, element, other *crypto.Element) {
}

// Verify whether they are equivalent
if element.Equal(other) != 1 {
if !element.Equal(other) {
t.Fatalf("Expected equality")
}

// Verify than operations on one don't affect the other
element.Add(element)
if element.Equal(other) == 1 {
if element.Equal(other) {
t.Fatalf(errUnExpectedEquality)
}

other.Double().Double()
if element.Equal(other) == 1 {
if element.Equal(other) {
t.Fatalf(errUnExpectedEquality)
}

// Verify setting to nil sets to identity
if element.Set(nil).Equal(other.Identity()) != 1 {
if !element.Set(nil).Equal(other.Identity()) {
t.Error(errExpectedEquality)
}
}
Expand Down Expand Up @@ -87,7 +87,7 @@ func TestElement_WrongInput(t *testing.T) {
}
}

equal := func(f func(*crypto.Element) int, arg *crypto.Element) func() {
equal := func(f func(*crypto.Element) bool, arg *crypto.Element) func() {
return func() {
f(arg)
}
Expand Down Expand Up @@ -254,19 +254,19 @@ func TestElement_Vectors_Add(t *testing.T) {

for _, mult := range group.multBase {
e := decodeElement(t, group.group, mult)
if e.Equal(acc) != 1 {
if !e.Equal(acc) {
t.Fatal("expected equality")
}

acc.Add(base)
}

base.Add(group.group.NewElement())
if base.Equal(group.group.Base()) != 1 {
if !base.Equal(group.group.Base()) {
t.Fatal(errExpectedEquality)
}

if group.group.NewElement().Add(base).Equal(base) != 1 {
if !group.group.NewElement().Add(base).Equal(base) {
t.Fatal(errExpectedEquality)
}
})
Expand All @@ -287,7 +287,7 @@ func TestElement_Vectors_Double(t *testing.T) {
e.Double()

v := decodeElement(t, group.group, group.multBase[multiple-1])
if v.Equal(e) != 1 {
if !v.Equal(e) {
t.Fatalf("expected equality for %d", multiple)
}
}
Expand All @@ -302,7 +302,7 @@ func TestElement_Vectors_Mult(t *testing.T) {

for i, mult := range group.multBase {
e := decodeElement(t, group.group, mult)
if e.Equal(base) != 1 {
if !e.Equal(base) {
t.Fatalf("expected equality for %d", i)
}

Expand All @@ -328,17 +328,17 @@ func elementTestEqual(t *testing.T, g crypto.Group) {
base := g.Base()
base2 := g.Base()

if base.Equal(nil) != 0 {
if base.Equal(nil) {
t.Fatal(errUnExpectedEquality)
}

if base.Equal(base2) != 1 {
if !base.Equal(base2) {
t.Fatal(errExpectedEquality)
}

random := g.NewElement().Multiply(g.NewScalar().Random())
cpy := random.Copy()
if random.Equal(cpy) != 1 {
if !random.Equal(cpy) {
t.Fatal()
}
}
Expand All @@ -347,37 +347,37 @@ func elementTestAdd(t *testing.T, g crypto.Group) {
// Verify whether add yields the same element when given nil
base := g.Base()
cpy := base.Copy()
if cpy.Add(nil).Equal(base) != 1 {
if !cpy.Add(nil).Equal(base) {
t.Fatal(errExpectedEquality)
}

// Verify whether add yields the same element when given identity
base = g.Base()
cpy = base.Copy()
cpy.Add(g.NewElement())
if cpy.Equal(base) != 1 {
if !cpy.Equal(base) {
t.Fatal(errExpectedEquality)
}

// Verify whether add yields the same when adding to identity
base = g.Base()
identity := g.NewElement()
if identity.Add(base).Equal(base) != 1 {
if !identity.Add(base).Equal(base) {
t.Fatal(errExpectedEquality)
}

// Verify whether add yields the identity given the negative
base = g.Base()
negative := g.Base().Negate()
identity = g.NewElement()
if base.Add(negative).Equal(identity) != 1 {
if !base.Add(negative).Equal(identity) {
t.Fatal(errExpectedEquality)
}

// Verify whether add yields the double when adding to itself
base = g.Base()
double := g.Base().Double()
if base.Add(base).Equal(double) != 1 {
if !base.Add(base).Equal(double) {
t.Fatal(errExpectedEquality)
}

Expand All @@ -389,7 +389,7 @@ func elementTestAdd(t *testing.T, g crypto.Group) {
mult := g.Base().Multiply(three)
e := g.Base().Add(g.Base()).Add(g.Base())

if e.Equal(mult) != 1 {
if !e.Equal(mult) {
t.Fatal(errExpectedEquality)
}
}
Expand All @@ -399,7 +399,7 @@ func elementTestNegate(t *testing.T, g crypto.Group) {
id := g.NewElement().Identity()
negId := g.NewElement().Identity().Negate()

if id.Equal(negId) != 1 {
if !id.Equal(negId) {
t.Fatal("expected equality when negating identity element")
}

Expand All @@ -416,7 +416,7 @@ func elementTestNegate(t *testing.T, g crypto.Group) {
b = g.NewElement().Base()
negB = g.NewElement().Base().Negate().Negate()

if b.Equal(negB) != 1 {
if !b.Equal(negB) {
t.Fatal("expected equality -(-b) = b")
}
}
Expand All @@ -425,13 +425,13 @@ func elementTestDouble(t *testing.T, g crypto.Group) {
// Verify whether double works like adding
base := g.Base()
double := g.Base().Add(g.Base())
if double.Equal(base.Double()) != 1 {
if !double.Equal(base.Double()) {
t.Fatal(errExpectedEquality)
}

two := g.NewScalar().One().Add(g.NewScalar().One())
mult := g.Base().Multiply(two)
if mult.Equal(double) != 1 {
if !mult.Equal(double) {
t.Fatal(errExpectedEquality)
}
}
Expand All @@ -440,13 +440,13 @@ func elementTestSubstract(t *testing.T, g crypto.Group) {
base := g.Base()

// Verify whether subtracting yields the same element when given nil.
if base.Subtract(nil).Equal(base) != 1 {
if !base.Subtract(nil).Equal(base) {
t.Fatal(errExpectedEquality)
}

// Verify whether subtracting and then adding yields the same element.
base2 := base.Add(base).Subtract(base)
if base.Equal(base2) != 1 {
if !base.Equal(base2) {
t.Fatal(errExpectedEquality)
}
}
Expand All @@ -457,7 +457,7 @@ func elementTestMultiply(t *testing.T, g crypto.Group) {
// base = base * 1
base := g.Base()
mult := g.Base().Multiply(scalar.One())
if base.Equal(mult) != 1 {
if !base.Equal(mult) {
t.Fatal(errExpectedEquality)
}

Expand All @@ -473,7 +473,7 @@ func elementTestMultiply(t *testing.T, g crypto.Group) {
two := g.NewScalar().One().Add(g.NewScalar().One())
mult = g.Base().Multiply(two)

if mult.Equal(twoG) != 1 {
if !mult.Equal(twoG) {
t.Fatal(errExpectedEquality)
}

Expand All @@ -495,30 +495,30 @@ func elementTestIdentity(t *testing.T, g crypto.Group) {
}

base := g.Base()
if id.Equal(base.Subtract(base)) != 1 {
if !id.Equal(base.Subtract(base)) {
log.Printf("id : %v", id.Encode())
log.Printf("ba : %v", base.Encode())
t.Fatal(errExpectedIdentity)
}

sub1 := g.Base().Double().Negate().Add(g.Base().Double())
sub2 := g.Base().Subtract(g.Base())
if sub1.Equal(sub2) != 1 {
if !sub1.Equal(sub2) {
t.Fatal(errExpectedEquality)
}

if id.Equal(base.Multiply(nil)) != 1 {
if !id.Equal(base.Multiply(nil)) {
t.Fatal(errExpectedIdentity)
}

if id.Equal(base.Multiply(g.NewScalar().Zero())) != 1 {
if !id.Equal(base.Multiply(g.NewScalar().Zero())) {
t.Fatal(errExpectedIdentity)
}

base = g.Base()
neg := base.Copy().Negate()
base.Add(neg)
if id.Equal(base) != 1 {
if !id.Equal(base) {
t.Fatal(errExpectedIdentity)
}
}
8 changes: 4 additions & 4 deletions tests/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func testScalarEncodings(g crypto.Group, f makeEncodeTest) error {
return err
}

if source.Equal(receiver) != 1 {
if !source.Equal(receiver) {
return errors.New(errExpectedEquality)
}

Expand All @@ -152,7 +152,7 @@ func testElementEncodings(g crypto.Group, f makeEncodeTest) error {
return err
}

if source.Equal(receiver) != 1 {
if !source.Equal(receiver) {
return errors.New(errExpectedEquality)
}

Expand Down Expand Up @@ -248,7 +248,7 @@ func TestEncoding_Hex_Fails(t *testing.T) {
t.Fatalf("unexpected error on valid encoding: %s", err)
}

if s.Equal(scalar) != 1 {
if !s.Equal(scalar) {
t.Fatal(errExpectedEquality)
}

Expand All @@ -258,7 +258,7 @@ func TestEncoding_Hex_Fails(t *testing.T) {
t.Fatalf("unexpected error on valid encoding: %s", err)
}

if e.Equal(element) != 1 {
if !e.Equal(element) {
t.Fatal(errExpectedEquality)
}
})
Expand Down
4 changes: 2 additions & 2 deletions tests/groups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func TestHashToScalar(t *testing.T) {
sv := decodeScalar(t, group.group, group.hashToCurve.hashToScalar)

s := group.group.HashToScalar(group.hashToCurve.input, group.hashToCurve.dst)
if s.Equal(sv) != 1 {
if !s.Equal(sv) {
t.Error(errExpectedEquality)
}
})
Expand Down Expand Up @@ -185,7 +185,7 @@ func TestHashToGroup(t *testing.T) {
ev := decodeElement(t, group.group, group.hashToCurve.hashToGroup)

e := group.group.HashToGroup(group.hashToCurve.input, group.hashToCurve.dst)
if e.Equal(ev) != 1 {
if !e.Equal(ev) {
t.Error(errExpectedEquality)
}
})
Expand Down
Loading
Loading