diff --git a/oauth2/__init__.py b/oauth2/__init__.py index 65c3f071..e782776f 100644 --- a/oauth2/__init__.py +++ b/oauth2/__init__.py @@ -35,6 +35,17 @@ except ImportError: from cgi import parse_qs, parse_qsl +try: + from Crypto.PublicKey import RSA + from Crypto.Util.number import long_to_bytes, bytes_to_long +except ImportError: + RSA=None + +try: + from hashlib import sha1 as sha +except ImportError: + import sha # Deprecated + VERSION = '1.0' # Hi Blaine! HTTP_METHOD = 'GET' @@ -705,7 +716,7 @@ def check(self, request, consumer, token, signature): built = self.sign(request, consumer, token) return built == signature - + class SignatureMethod_HMAC_SHA1(SignatureMethod): name = 'HMAC-SHA1' @@ -730,17 +741,67 @@ def sign(self, request, consumer, token): key, raw = self.signing_base(request, consumer, token) # HMAC object. - try: - from hashlib import sha1 as sha - except ImportError: - import sha # Deprecated - hashed = hmac.new(key, raw, sha) # Calculate the digest base 64. return binascii.b2a_base64(hashed.digest())[:-1] +class SignatureMethod_RSA_SHA1(SignatureMethod): + name = 'RSA-SHA1' + + def signing_base(self, request, consumer, token): + if request.normalized_url is None: + raise ValueError("Base URL for request is not set.") + + sig = ( + escape(request.method), + escape(request.normalized_url), + escape(request.get_normalized_parameters()), + ) + + key = consumer.secret + raw = '&'.join(sig) + return key, raw + + def sign(self, request, consumer, token): + """Builds the base signature string.""" + if RSA is None: raise NotImplementedError, self.name + key, raw = self.signing_base(request, consumer, token) + + digest = sha(raw).digest() + sig = key.sign(self._pkcs1imify(key, digest), '')[0] + sig_bytes = long_to_bytes(sig) + # Calculate the digest base 64. + return binascii.b2a_base64(sig_bytes)[:-1] + + def check(self, request, consumer, token, signature): + """Returns whether the given signature is the correct signature for + the given consumer and token signing the given request.""" + if RSA is None: raise NotImplementedError, self.name + key, raw = self.signing_base(request, consumer, token) + + digest = sha(raw).digest() + sig = bytes_to_long(binascii.a2b_base64(signature)) + data = self._pkcs1imify(key, digest) + + pubkey = key.publickey() + return pubkey.verify(data, (sig,)) + + @staticmethod + def _pkcs1imify(key, data): + """Adapted from paramiko + + turn a 20-byte SHA1 hash into a blob of data as large as the key's N, + using PKCS1's \"emsa-pkcs1-v1_5\" encoding. totally bizarre. + """ + SHA1_DIGESTINFO = '\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14' + size = len(long_to_bytes(key.n)) + filler = '\xff' * (size - len(SHA1_DIGESTINFO) - len(data) - 3) + return '\x00\x01' + filler + '\x00' + SHA1_DIGESTINFO + data + + + class SignatureMethod_PLAINTEXT(SignatureMethod): name = 'PLAINTEXT' diff --git a/tests/test_oauth.py b/tests/test_oauth.py index e2a87f97..5614a6da 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -664,6 +664,58 @@ def test_from_consumer_and_token(self): self.assertEquals(req['oauth_consumer_key'], con.key) self.assertEquals(tok.verifier, req['oauth_verifier']) +class TestRSASignature(unittest.TestCase): + + def setUp(self): + from Crypto.PublicKey import RSA + self.RSA = RSA + key=RSA.importKey('''-----BEGIN RSA PRIVATE KEY----- +MIIBOgIBAAJBAM7B+5TJsc93ymBSFtC5DE1qDlqvwio0xDfS6bZQTfFiHLm8pHXg +Atkm7QB6gvyRKm+a/G3qEbmBdz21Fw0RLJsCAwEAAQJAS68qnr5uPlnFVRj3jRQP +8s6dzoiD9Ns38I9eSgR/Y5ozl8r/cClLeGWvDKfXvrxlsaMuqWLZ5KMtamaRS9Fl +sQIhAPmOY+s5ZxsYtem+Uc2IUGexNoP/Ng7MPS3C+Q3L6K4nAiEA1Biv6i7TqAbx +oHulPIXb2Z9JmO46aT81n9WnD1qyim0CIF9eN/cLf8iOH+7MqYxHHJsT0QaOgEUV +bgfP68eG9kufAiEAtUSAHGp29HUyzxC9sNNKiVysnuqDu22NXBRSmjnOu6UCIEFZ +nqb0GVzfF6wbsf40mkp1kdHq/fNiFRrLYWWJSpGY +-----END RSA PRIVATE KEY-----''') + self.method = oauth.SignatureMethod_RSA_SHA1() + self.tok = oauth.Token(key="tok-test-key", secret="tok-test-secret") + self.con = oauth.Consumer(key="con-test-key", secret=key) + self.url = "http://sp.example.com/" + + self.params = { + 'oauth_version': "1.0", + 'oauth_nonce': "4572616e48616d6d65724c61686176", + 'oauth_timestamp': "137131200", + 'oauth_token': self.tok.key, + 'oauth_consumer_key': self.con.key, + 'bar': 'blerg', + 'multi': ['FOO','BAR'], + 'foo': 59 + } + + def test_sign(self): + req = oauth.Request(method="GET", url=self.url, parameters=self.params) + req.sign_request(self.method, self.con, self.tok) + self.assertEquals(req['oauth_signature_method'], 'RSA-SHA1') + self.assertEquals(req['oauth_signature'], 'D2rdx9TiFajZbXChqMca6eaal8FxZhLMU1bdNX0glIN+BT4nrYGJqmIW92kWZYEYKHsVz7e67oDBEYlIIQMKWg==') + + def test_verify(self): + self.params['oauth_timestamp'] = int(time.time()) + req = oauth.Request(method="GET", url=self.url, parameters=self.params) + server = oauth.Server() + server.add_signature_method(self.method) + + req.sign_request(self.method, self.con, self.tok) + parameters = server.verify_request(req, self.con,self.tok) + + self.assertTrue('bar' in parameters) + self.assertTrue('foo' in parameters) + self.assertTrue('multi' in parameters) + self.assertEquals(parameters['bar'], 'blerg') + self.assertEquals(parameters['foo'], 59) + self.assertEquals(parameters['multi'], ['FOO','BAR']) + class SignatureMethod_Bad(oauth.SignatureMethod): name = "BAD"