From cad5d3504ee1b93a779df0d08d349db28f316cd1 Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Tue, 4 Jun 2024 08:26:51 +0800 Subject: [PATCH] sm2: public recover from signature 2 --- sm2/sm2.go | 57 ++++++++++++++++++++++++++++++++----------------- sm2/sm2_test.go | 8 +++---- 2 files changed, 41 insertions(+), 24 deletions(-) diff --git a/sm2/sm2.go b/sm2/sm2.go index 1ec36950..9af620d7 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -707,7 +707,7 @@ func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) { var ErrInvalidSignature = errors.New("sm2: invalid signature") -// RecoverPublicKeysFromSM2Signature recovers two SM2 public keys from a given signature and hash. +// RecoverPublicKeysFromSM2Signature recovers two or four SM2 public keys from a given signature and hash. // It takes the hash and signature as input and returns the recovered public keys as []*ecdsa.PublicKey. // If the signature or hash is invalid, it returns an error. // The function follows the SM2 algorithm to recover the public keys. @@ -741,38 +741,55 @@ func RecoverPublicKeysFromSM2Signature(hash, sig []byte) ([]*ecdsa.PublicKey, er if s.IsZero() == 1 { return nil, ErrInvalidSignature } + // sBytes = (r+s)⁻¹ sBytes, err = _sm2ec.P256OrdInverse(s.Bytes(c.N)) if err != nil { return nil, err } + // r = (Rx + e) mod N // Rx = r - e r.Sub(e, c.N) if r.IsZero() == 1 { return nil, ErrInvalidSignature } - rBytes = r.Bytes(c.N) - tmp := make([]byte, len(rBytes)+1) - copy(tmp[1:], rBytes) + pointRx := make([]*bigmod.Nat, 0, 2) + pointRx = append(pointRx, r) + // check if Rx in (N, P), small probability event + s.Set(r) + s = s.Add(c.N.Nat(), c.P) + if s.CmpGeq(c.N.Nat()) == 1 { + pointRx = append(pointRx, s) + } + pubs := make([]*ecdsa.PublicKey, 0, 4) + bytes := make([]byte, len(rBytes)+1) compressFlags := []byte{compressed02, compressed03} - pks := make([]*ecdsa.PublicKey, 0, 2) - for _, flag := range compressFlags { - tmp[0] = flag - p0, err := c.newPoint().SetBytes(tmp) - if err != nil { - return nil, err - } - p0.Add(p0, p1) - p0.ScalarMult(p0, sBytes) - pk := new(ecdsa.PublicKey) - pk.Curve = c.curve - pk.X, pk.Y, err = c.pointToAffine(p0) - if err != nil { - return nil, err + // Rx has one or two possible values, so point R has two or four possible values + for _, x := range pointRx { + rBytes = x.Bytes(c.N) + copy(bytes[1:], rBytes) + for _, flag := range compressFlags { + bytes[0] = flag + // p0 = R + p0, err := c.newPoint().SetBytes(bytes) + if err != nil { + return nil, err + } + // p0 = R - [s]G + p0.Add(p0, p1) + // Pub = [(r + s)⁻¹](R - [s]G) + p0.ScalarMult(p0, sBytes) + pub := new(ecdsa.PublicKey) + pub.Curve = c.curve + pub.X, pub.Y, err = c.pointToAffine(p0) + if err != nil { + return nil, err + } + pubs = append(pubs, pub) } - pks = append(pks, pk) } - return pks, nil + + return pubs, nil } // VerifyASN1 verifies the ASN.1 encoded signature, sig, of hash using the diff --git a/sm2/sm2_test.go b/sm2/sm2_test.go index c998fb2b..97c79ea7 100644 --- a/sm2/sm2_test.go +++ b/sm2/sm2_test.go @@ -470,7 +470,7 @@ func TestSignVerify(t *testing.T) { } } -func TestRecoverSM2PublicKeyFromSig(t *testing.T) { +func TestRecoverPublicKeysFromSM2Signature(t *testing.T) { priv, _ := GenerateKey(rand.Reader) tests := []struct { name string @@ -493,19 +493,19 @@ func TestRecoverSM2PublicKeyFromSig(t *testing.T) { pubs, err := RecoverPublicKeysFromSM2Signature(hashValue, sig) if err != nil { - t.Fatalf("recover failed %v", err) + t.Fatalf("recover sig=%x, priv=%x, failed %v", sig, priv.D.Bytes(), err) } found := false for _, pub := range pubs { if !VerifyASN1(pub, hashValue, sig) { - t.Errorf("failed to verify hash") + t.Errorf("failed to verify hash for sig=%x, priv=%x", sig, priv.D.Bytes()) } if pub.Equal(&priv.PublicKey) { found = true } } if !found { - t.Errorf("recover failed, not found public key") + t.Errorf("recover failed, not found public key for sig=%x, priv=%x", sig, priv.D.Bytes()) } }) }