diff --git a/element.go b/element.go index c22b5ae..b1cdc61 100644 --- a/element.go +++ b/element.go @@ -146,6 +146,16 @@ func (e *Element) DecodeHex(h string) error { return nil } +// MarshalJSON marshals the element into valid JSON. +func (e *Element) MarshalJSON() ([]byte, error) { + return e.Encode(), nil +} + +// UnmarshalJSON unmarshals the input into the element. +func (e *Element) UnmarshalJSON(data []byte) error { + return e.Decode(data) +} + // MarshalBinary returns the compressed byte encoding of the element. func (e *Element) MarshalBinary() ([]byte, error) { dec, err := e.Element.MarshalBinary() diff --git a/scalar.go b/scalar.go index fe3e621..e2b7316 100644 --- a/scalar.go +++ b/scalar.go @@ -178,6 +178,16 @@ func (s *Scalar) DecodeHex(h string) error { return nil } +// MarshalJSON marshals the scalar into valid JSON. +func (s *Scalar) MarshalJSON() ([]byte, error) { + return s.Encode(), nil +} + +// UnmarshalJSON unmarshals the input into the scalar. +func (s *Scalar) UnmarshalJSON(data []byte) error { + return s.Decode(data) +} + // MarshalBinary implements the encoding.BinaryMarshaler interface. func (s *Scalar) MarshalBinary() ([]byte, error) { dec, err := s.Scalar.MarshalBinary() diff --git a/tests/encoding_test.go b/tests/encoding_test.go new file mode 100644 index 0000000..3e98774 --- /dev/null +++ b/tests/encoding_test.go @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (C) 2024 Daniel Bourdrez. All Rights Reserved. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree or at +// https://spdx.org/licenses/MIT.html + +package group_test + +import ( + "bytes" + "encoding" + "encoding/hex" + "strings" + "testing" +) + +type serde interface { + Encode() []byte + Decode(data []byte) error + MarshalJSON() ([]byte, error) + UnmarshalJSON(data []byte) error + Hex() string + DecodeHex(h string) error + encoding.BinaryMarshaler + encoding.BinaryUnmarshaler +} + +func testEncoding(t *testing.T, thing1, thing2 serde) { + // empty string + if err := thing2.DecodeHex(""); err == nil { + t.Fatal("expected error on empty string") + } + + encoded := thing1.Encode() + marshalled, _ := thing1.MarshalBinary() + hexed := thing1.Hex() + + jsoned, err := thing1.MarshalJSON() + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(encoded, marshalled) { + t.Fatalf("Encode() and MarshalBinary() are expected to have the same output."+ + "\twant: %v\tgot : %v", encoded, marshalled) + } + + if hex.EncodeToString(encoded) != hexed { + t.Fatalf("Failed hex encoding, want %q, got %q", hex.EncodeToString(encoded), hexed) + } + + if err := thing2.Decode(nil); err == nil { + t.Fatal("expected error on Decode() with nil input") + } + + if err := thing2.Decode(encoded); err != nil { + t.Fatalf("Decode() failed on a valid encoding: %v. Value: %v", err, hex.EncodeToString(encoded)) + } + + if err := thing2.UnmarshalJSON(jsoned); err != nil { + t.Fatalf("UnmarshalJSON() failed on a valid encoding: %v", err) + } + + if err := thing2.UnmarshalBinary(encoded); err != nil { + t.Fatalf("UnmarshalBinary() failed on a valid encoding: %v", err) + } + + if err := thing2.DecodeHex(hexed); err != nil { + t.Fatalf("DecodeHex() failed on valid hex encoding: %v", err) + } +} + +func TestEncoding(t *testing.T) { + testAll(t, func(group *testGroup) { + g := group.group + scalar := g.NewScalar().Random() + testEncoding(t, scalar, g.NewScalar()) + + scalar = g.NewScalar().Random() + element := g.Base().Multiply(scalar) + testEncoding(t, element, g.NewElement()) + }) +} + +func testDecodingHexFails(t *testing.T, thing1, thing2 serde) { + // empty string + if err := thing2.DecodeHex(""); err == nil { + t.Fatal("expected error on empty string") + } + + // malformed string + hexed := thing1.Hex() + malformed := []rune(hexed) + malformed[0] = []rune("_")[0] + + if err := thing2.DecodeHex(string(malformed)); err == nil { + t.Fatal("expected error on malformed string") + } else if !strings.HasSuffix(err.Error(), "DecodeHex: encoding/hex: invalid byte: U+005F '_'") { + t.Fatalf("unexpected error: %q", err) + } +} + +func TestEncoding_Hex_Fails(t *testing.T) { + testAll(t, func(group *testGroup) { + g := group.group + scalar := g.NewScalar().Random() + testEncoding(t, scalar, g.NewScalar()) + + scalar = g.NewScalar().Random() + element := g.Base().Multiply(scalar) + testEncoding(t, element, g.NewElement()) + + // Hex fails + testDecodingHexFails(t, scalar, g.NewScalar()) + testDecodingHexFails(t, element, g.NewElement()) + + // Doesn't yield the same decoded result + scalar = g.NewScalar().Random() + s := g.NewScalar() + if err := s.DecodeHex(scalar.Hex()); err != nil { + t.Fatalf("unexpected error on valid encoding: %s", err) + } + + if s.Equal(scalar) != 1 { + t.Fatal(errExpectedEquality) + } + + element = g.Base().Multiply(scalar) + e := g.NewElement() + if err := e.DecodeHex(element.Hex()); err != nil { + t.Fatalf("unexpected error on valid encoding: %s", err) + } + + if e.Equal(element) != 1 { + t.Fatal(errExpectedEquality) + } + }) +} diff --git a/tests/utils_test.go b/tests/utils_test.go index f334ba6..66242b4 100644 --- a/tests/utils_test.go +++ b/tests/utils_test.go @@ -9,12 +9,9 @@ package group_test import ( - "bytes" - "encoding" "encoding/hex" "errors" "fmt" - "strings" "testing" "github.com/bytemare/crypto" @@ -96,115 +93,3 @@ func decodeElement(t *testing.T, g crypto.Group, input string) *crypto.Element { return e } - -type serde interface { - Encode() []byte - Decode(data []byte) error - Hex() string - DecodeHex(h string) error - encoding.BinaryMarshaler - encoding.BinaryUnmarshaler -} - -func testEncoding(t *testing.T, thing1, thing2 serde) { - // empty string - if err := thing2.DecodeHex(""); err == nil { - t.Fatal("expected error on empty string") - } - - encoded := thing1.Encode() - marshalled, _ := thing1.MarshalBinary() - hexed := thing1.Hex() - - if !bytes.Equal(encoded, marshalled) { - t.Fatalf("Encode() and MarshalBinary() are expected to have the same output."+ - "\twant: %v\tgot : %v", encoded, marshalled) - } - - if hex.EncodeToString(encoded) != hexed { - t.Fatalf("Failed hex encoding, want %q, got %q", hex.EncodeToString(encoded), hexed) - } - - if err := thing2.Decode(nil); err == nil { - t.Fatal("expected error on Decode() with nil input") - } - - if err := thing2.Decode(encoded); err != nil { - t.Fatalf("Decode() failed on a valid encoding: %v. Value: %v", err, hex.EncodeToString(encoded)) - } - - if err := thing2.UnmarshalBinary(encoded); err != nil { - t.Fatalf("UnmarshalBinary() failed on a valid encoding: %v", err) - } - - if err := thing2.DecodeHex(hexed); err != nil { - t.Fatalf("DecodeHex() failed on valid hex encoding: %v", err) - } -} - -func TestEncoding(t *testing.T) { - testAll(t, func(group *testGroup) { - g := group.group - scalar := g.NewScalar().Random() - testEncoding(t, scalar, g.NewScalar()) - - scalar = g.NewScalar().Random() - element := g.Base().Multiply(scalar) - testEncoding(t, element, g.NewElement()) - }) -} - -func testDecodingHexFails(t *testing.T, thing1, thing2 serde) { - // empty string - if err := thing2.DecodeHex(""); err == nil { - t.Fatal("expected error on empty string") - } - - // malformed string - hexed := thing1.Hex() - malformed := []rune(hexed) - malformed[0] = []rune("_")[0] - - if err := thing2.DecodeHex(string(malformed)); err == nil { - t.Fatal("expected error on malformed string") - } else if !strings.HasSuffix(err.Error(), "DecodeHex: encoding/hex: invalid byte: U+005F '_'") { - t.Fatalf("unexpected error: %q", err) - } -} - -func TestEncoding_Hex_Fails(t *testing.T) { - testAll(t, func(group *testGroup) { - g := group.group - scalar := g.NewScalar().Random() - testEncoding(t, scalar, g.NewScalar()) - - scalar = g.NewScalar().Random() - element := g.Base().Multiply(scalar) - testEncoding(t, element, g.NewElement()) - - // Hex fails - testDecodingHexFails(t, scalar, g.NewScalar()) - testDecodingHexFails(t, element, g.NewElement()) - - // Doesn't yield the same decoded result - scalar = g.NewScalar().Random() - s := g.NewScalar() - if err := s.DecodeHex(scalar.Hex()); err != nil { - t.Fatalf("unexpected error on valid encoding: %s", err) - } - - if s.Equal(scalar) != 1 { - t.Fatal(errExpectedEquality) - } - - element = g.Base().Multiply(scalar) - e := g.NewElement() - if err := e.DecodeHex(element.Hex()); err != nil { - t.Fatalf("unexpected error on valid encoding: %s", err) - } - - if e.Equal(element) != 1 { - t.Fatal(errExpectedEquality) - } - }) -}