diff --git a/grpc_publisher_service.py b/grpc_publisher_service.py index da87c5b..ea19ad7 100644 --- a/grpc_publisher_service.py +++ b/grpc_publisher_service.py @@ -34,126 +34,127 @@ logger = logging.getLogger("[gRPC Publisher Service]") -def create_update_token_context( - device_id, account_identifier, platform_name, response, context -): - """ - Creates a context-specific token update function. - - Args: - device_id (str): The unique identifier of the device. - account_identifier (str): The identifier for the account - (e.g., email or username). - platform_name (str): The name of the platform (e.g., 'gmail'). - response (protobuf message class): The response class for the gRPC method. - context (grpc.ServicerContext): The gRPC context for the current method call. - - Returns: - function: A function `update_token(token)` that updates the token information. - """ - - def update_token(token, **kwargs): +class PublisherService(publisher_pb2_grpc.PublisherServicer): + """Publisher Service Descriptor""" + + def handle_create_grpc_error_response( + self, context, response, sys_msg, status_code, **kwargs + ): """ - Updates the stored token for a specific entity. + Handles the creation of a gRPC error response. Args: - token (dict or object): The token information - containing access and refresh tokens. + context: gRPC context. + response: gRPC response object. + sys_msg (str or tuple): System message. + status_code: gRPC status code. + user_msg (str or tuple): User-friendly message. + error_type (str): Type of error. + + Returns: + An instance of the specified response with the error set. """ - logger.info( - "Updating token for device_id: %s, platform: %s", - device_id, - platform_name, - ) - - update_entity_token_response, update_entity_token_error = update_entity_token( - device_id=device_id, - token=json.dumps(token), - account_identifier=account_identifier, - platform=platform_name, - ) - - if update_entity_token_error: - return error_response( - context, - response, - update_entity_token_error.details(), - update_entity_token_error.code(), - ) + user_msg = kwargs.get("user_msg") + error_type = kwargs.get("error_type") - if not update_entity_token_response.success: - return response( - message=update_entity_token_response.message, - success=update_entity_token_response.success, - ) + if not user_msg: + user_msg = sys_msg + + if error_type == "UNKNOWN": + logger.exception(sys_msg, exc_info=True) + else: + logger.error(sys_msg) - return True + context.set_details(user_msg) + context.set_code(status_code) - return update_token + return response() + def handle_request_field_validation( + self, context, request, response, required_fields + ): + """ + Validates the fields in the gRPC request. -def error_response(context, response, sys_msg, status_code, user_msg=None, _type=None): - """ - Create an error response. + Args: + context: gRPC context. + request: gRPC request object. + response: gRPC response object. + required_fields (list): List of required fields, can include tuples. + + Returns: + None or response: None if no missing fields, + error response otherwise. + """ - Args: - context: gRPC context. - response: gRPC response object. - sys_msg (str or tuple): System message. - status_code: gRPC status code. - user_msg (str or tuple): User-friendly message. - _type (str): Type of error. + def validate_field(field): + if not getattr(request, field, None): + return self.handle_create_grpc_error_response( + context, + response, + f"Missing required field: {field}", + grpc.StatusCode.INVALID_ARGUMENT, + ) - Returns: - An instance of the specified response with the error set. - """ - if not user_msg: - user_msg = sys_msg + return None - if isinstance(user_msg, tuple): - user_msg = "".join(user_msg) - if isinstance(sys_msg, tuple): - sys_msg = "".join(sys_msg) + for field in required_fields: + validation_error = validate_field(field) + if validation_error: + return validation_error - if _type == "UNKNOWN": - logger.exception(sys_msg, exc_info=True) - else: - logger.error(sys_msg) + return None - context.set_details(user_msg) - context.set_code(status_code) + def create_token_update_handler(self, response_cls, grpc_context, **kwargs): + """ + Creates a function to handle updating the token for a specific device and account. - return response() + Args: + device_id (str): The unique identifier of the device. + account_id (str): The identifier for the account (e.g., email or username). + platform (str): The name of the platform (e.g., 'gmail'). + response_cls (protobuf message class): The response class for the gRPC method. + grpc_context (grpc.ServicerContext): The gRPC context for the current method call. + + Returns: + function: A function `handle_token_update(token)` that updates the token information. + """ + device_id = kwargs["device_id"] + account_id = kwargs["account_id"] + platform = kwargs["platform"] + def handle_token_update(token, **kwargs): + """ + Handles updating the stored token for the specified device and account. -def validate_request_fields(context, request, response, required_fields): - """ - Validates the fields in the gRPC request. + Args: + token (dict or object): The token information containing access and refresh tokens. + """ - Args: - context: gRPC context. - request: gRPC request object. - response: gRPC response object. - required_fields (list): List of required fields. + update_response, update_error = update_entity_token( + device_id=device_id, + token=json.dumps(token), + account_identifier=account_id, + platform=platform, + ) - Returns: - None or response: None if no missing fields, - error response otherwise. - """ - missing_fields = [field for field in required_fields if not getattr(request, field)] - if missing_fields: - return error_response( - context, - response, - f"Missing required fields: {', '.join(missing_fields)}", - grpc.StatusCode.INVALID_ARGUMENT, - ) + if update_error: + return self.handle_create_grpc_error_response( + grpc_context, + response_cls, + update_error.details(), + update_error.code(), + ) - return None + if not update_response.success: + return response_cls( + message=update_response.message, + success=update_response.success, + ) + return True -class PublisherService(publisher_pb2_grpc.PublisherServicer): - """Publisher Service Descriptor""" + return handle_token_update def GetOAuth2AuthorizationUrl(self, request, context): """Handles generating OAuth2 authorization URL""" @@ -161,25 +162,14 @@ def GetOAuth2AuthorizationUrl(self, request, context): response = publisher_pb2.GetOAuth2AuthorizationUrlResponse def validate_fields(): - return validate_request_fields( + return self.handle_request_field_validation( context, request, response, ["platform"], ) - try: - invalid_fields_response = validate_fields() - if invalid_fields_response: - return invalid_fields_response - - check_platform_supported(request.platform.lower(), "oauth2") - - oauth2_client = OAuth2Client(request.platform) - - if request.redirect_url: - oauth2_client.session.redirect_uri = request.redirect_url - + def handle_authorization(oauth2_client): extra_params = { "state": getattr(request, "state") or None, "code_verifier": getattr(request, "code_verifier") or None, @@ -202,8 +192,22 @@ def validate_fields(): message="Successfully generated authorization url", ) + try: + invalid_fields_response = validate_fields() + if invalid_fields_response: + return invalid_fields_response + + check_platform_supported(request.platform.lower(), "oauth2") + + oauth2_client = OAuth2Client(request.platform) + + if request.redirect_url: + oauth2_client.session.redirect_uri = request.redirect_url + + return handle_authorization(oauth2_client) + except NotImplementedError as e: - return error_response( + return self.handle_create_grpc_error_response( context, response, str(e), @@ -211,7 +215,7 @@ def validate_fields(): ) except Exception as exc: - return error_response( + return self.handle_create_grpc_error_response( context, response, exc, @@ -226,7 +230,7 @@ def ExchangeOAuth2CodeAndStore(self, request, context): response = publisher_pb2.ExchangeOAuth2CodeAndStoreResponse def validate_fields(): - return validate_request_fields( + return self.handle_request_field_validation( context, request, response, @@ -238,7 +242,7 @@ def list_tokens(): long_lived_token=request.long_lived_token ) if list_error: - return None, error_response( + return None, self.handle_create_grpc_error_response( context, response, list_error.details(), @@ -259,7 +263,7 @@ def fetch_token_and_profile(): ) if not token.get("refresh_token"): - return None, error_response( + return None, self.handle_create_grpc_error_response( context, response, "invalid token: No refresh token present.", @@ -270,7 +274,7 @@ def fetch_token_and_profile(): expected_scopes = set(scope) if not expected_scopes.issubset(fetched_scopes): - return None, error_response( + return None, self.handle_create_grpc_error_response( context, response, "invalid token: Scopes do not match. Expected: " @@ -292,7 +296,7 @@ def store_token(token, profile): ) if store_error: - return error_response( + return self.handle_create_grpc_error_response( context, response, store_error.details(), @@ -328,7 +332,7 @@ def store_token(token, profile): return store_token(*fetched_data) except OAuthError as e: - return error_response( + return self.handle_create_grpc_error_response( context, response, str(e), @@ -337,7 +341,7 @@ def store_token(token, profile): ) except NotImplementedError as e: - return error_response( + return self.handle_create_grpc_error_response( context, response, str(e), @@ -345,7 +349,7 @@ def store_token(token, profile): ) except Exception as exc: - return error_response( + return self.handle_create_grpc_error_response( context, response, exc, @@ -360,7 +364,7 @@ def RevokeAndDeleteOAuth2Token(self, request, context): response = publisher_pb2.RevokeAndDeleteOAuth2TokenResponse def validate_fields(): - return validate_request_fields( + return self.handle_request_field_validation( context, request, response, @@ -374,7 +378,7 @@ def get_access_token(): long_lived_token=request.long_lived_token, ) if get_access_token_error: - return None, error_response( + return None, self.handle_create_grpc_error_response( context, response, get_access_token_error.details(), @@ -398,7 +402,7 @@ def delete_token(): ) if delete_token_error: - return error_response( + return self.handle_create_grpc_error_response( context, response, delete_token_error.details(), @@ -428,7 +432,7 @@ def delete_token(): return delete_token() except NotImplementedError as e: - return error_response( + return self.handle_create_grpc_error_response( context, response, str(e), @@ -436,7 +440,7 @@ def delete_token(): ) except Exception as exc: - return error_response( + return self.handle_create_grpc_error_response( context, response, exc, @@ -451,14 +455,16 @@ def PublishContent(self, request, context): response = publisher_pb2.PublishContentResponse def validate_fields(): - return validate_request_fields(context, request, response, ["content"]) + return self.handle_request_field_validation( + context, request, response, ["content"] + ) def decode_payload(): platform_letter, encrypted_content, device_id, decode_error = ( decode_relay_sms_payload(request.content) ) if decode_error: - return None, error_response( + return None, self.handle_create_grpc_error_response( context, response, decode_error, @@ -473,7 +479,7 @@ def get_platform_info(platform_letter): platform_letter ) if platform_info is None: - return None, error_response( + return None, self.handle_create_grpc_error_response( context, response, platform_err, @@ -488,7 +494,7 @@ def get_access_token(device_id, platform_name, account_identifier): account_identifier=account_identifier, ) if get_access_token_error: - return None, error_response( + return None, self.handle_create_grpc_error_response( context, response, get_access_token_error.details(), @@ -506,7 +512,7 @@ def decrypt_message(device_id, encrypted_content): device_id.hex(), base64.b64encode(encrypted_content).decode("utf-8") ) if decrypt_payload_error: - return None, error_response( + return None, self.handle_create_grpc_error_response( context, response, decrypt_payload_error.details(), @@ -524,7 +530,7 @@ def encrypt_message(device_id, plaintext): device_id.hex(), plaintext ) if encrypt_payload_error: - return None, error_response( + return None, self.handle_create_grpc_error_response( context, response, encrypt_payload_error.details(), @@ -541,7 +547,7 @@ def handle_oauth2_email(device_id, platform_name, service_type, payload, token): content_parts, parse_error = parse_content(service_type, payload) if parse_error: - return error_response( + return self.handle_create_grpc_error_response( context, response, parse_error, @@ -560,8 +566,12 @@ def handle_oauth2_email(device_id, platform_name, service_type, payload, token): oauth2_client = OAuth2Client( platform_name, json.loads(token), - create_update_token_context( - device_id.hex(), from_email, platform_name, response, context + self.create_token_update_handler( + device_id=device_id.hex(), + account_id=from_email, + platform=platform_name, + response_cls=response, + grpc_context=context, ), ) return oauth2_client.send_message(email_message, from_email) @@ -570,7 +580,7 @@ def handle_oauth2_text(device_id, platform_name, service_type, payload, token): content_parts, parse_error = parse_content(service_type, payload) if parse_error: - return error_response( + return self.handle_create_grpc_error_response( context, response, parse_error, @@ -581,8 +591,12 @@ def handle_oauth2_text(device_id, platform_name, service_type, payload, token): oauth2_client = OAuth2Client( platform_name, json.loads(token), - create_update_token_context( - device_id.hex(), sender, platform_name, response, context + self.create_token_update_handler( + device_id=device_id.hex(), + account_id=sender, + platform=platform_name, + response_cls=response, + grpc_context=context, ), ) return oauth2_client.send_message(text) @@ -646,7 +660,7 @@ def handle_oauth2_text(device_id, platform_name, service_type, payload, token): ) except OAuthError as e: - return error_response( + return self.handle_create_grpc_error_response( context, response, str(e), @@ -655,7 +669,7 @@ def handle_oauth2_text(device_id, platform_name, service_type, payload, token): ) except Exception as exc: - return error_response( + return self.handle_create_grpc_error_response( context, response, exc, diff --git a/grpc_vault_entity_client.py b/grpc_vault_entity_client.py index 17e6088..a97afd1 100644 --- a/grpc_vault_entity_client.py +++ b/grpc_vault_entity_client.py @@ -1,6 +1,7 @@ """Vault gRPC Client""" import logging +import functools import grpc import vault_pb2 @@ -44,7 +45,34 @@ def get_channel(internal=True): return grpc.insecure_channel(f"{hostname}:{port}") -def store_entity_token(long_lived_token, token, platform, account_identifier): +def grpc_call(internal=True): + """Decorator to handle gRPC calls.""" + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + channel = get_channel(internal) + + with channel as conn: + kwargs["stub"] = ( + vault_pb2_grpc.EntityInternalStub(conn) + if internal + else vault_pb2_grpc.EntityStub(conn) + ) + return func(*args, **kwargs) + except grpc.RpcError as e: + return None, e + except Exception as e: + raise e + + return wrapper + + return decorator + + +@grpc_call() +def store_entity_token(long_lived_token, token, platform, account_identifier, **kwargs): """Store an entity token in the vault. Args: @@ -58,29 +86,23 @@ def store_entity_token(long_lived_token, token, platform, account_identifier): - server response (object): The vault server response. - error (Exception): The error encountered if the request fails, otherwise None. """ - try: - channel = get_channel() - - with channel as conn: - stub = vault_pb2_grpc.EntityInternalStub(conn) - request = vault_pb2.StoreEntityTokenRequest( - long_lived_token=long_lived_token, - token=token, - platform=platform, - account_identifier=account_identifier, - ) - - logger.debug("Storing token for platform '%s'", platform) - response = stub.StoreEntityToken(request) - logger.info("Successfully stored token for platform '%s'", platform) - return response, None - except grpc.RpcError as e: - return None, e - except Exception as e: - raise e - - -def list_entity_stored_tokens(long_lived_token): + stub = kwargs["stub"] + + request = vault_pb2.StoreEntityTokenRequest( + long_lived_token=long_lived_token, + token=token, + platform=platform, + account_identifier=account_identifier, + ) + + logger.debug("Storing token for platform '%s'", platform) + response = stub.StoreEntityToken(request) + logger.info("Successfully stored token for platform '%s'", platform) + return response, None + + +@grpc_call(False) +def list_entity_stored_tokens(long_lived_token, **kwargs): """Fetches and lists an entity's stored tokens from the vault. Args: @@ -93,33 +115,22 @@ def list_entity_stored_tokens(long_lived_token): - error (Exception): The error encountered if the request fails, otherwise None. """ - try: - channel = get_channel(internal=False) - - with channel as conn: - stub = vault_pb2_grpc.EntityStub(conn) - request = vault_pb2.ListEntityStoredTokensRequest( - long_lived_token=long_lived_token - ) - - logger.debug( - "Sending request to list stored tokens for long_lived_token: %s", - long_lived_token, - ) - response = stub.ListEntityStoredTokens(request) - tokens = response.stored_tokens - - logger.info("Successfully retrieved stored tokens.") - return tokens, None - except grpc.RpcError as e: - return None, e - except Exception as e: - raise e - - -def get_entity_access_token( - platform, account_identifier, device_id=None, long_lived_token=None -): + stub = kwargs["stub"] + request = vault_pb2.ListEntityStoredTokensRequest(long_lived_token=long_lived_token) + + logger.debug( + "Sending request to list stored tokens for long_lived_token: %s", + long_lived_token, + ) + response = stub.ListEntityStoredTokens(request) + tokens = response.stored_tokens + + logger.info("Successfully retrieved stored tokens.") + return tokens, None + + +@grpc_call() +def get_entity_access_token(platform, account_identifier, **kwargs): """ Retrieves an entity access token. @@ -134,39 +145,35 @@ def get_entity_access_token( - server response (object): The vault server response. - error (Exception): The error encountered if the request fails, otherwise None. """ - try: - channel = get_channel() - - with channel as conn: - stub = vault_pb2_grpc.EntityInternalStub(conn) - request = vault_pb2.GetEntityAccessTokenRequest( - device_id=device_id, - long_lived_token=long_lived_token, - platform=platform, - account_identifier=account_identifier, - ) - - identifier = device_id or long_lived_token - logger.debug( - "Requesting access tokens for %s '%s'...", - "device_id" if device_id else "long_lived_token", - identifier, - ) - response = stub.GetEntityAccessToken(request) - - logger.info( - "Successfully retrieved access token for %s '%s'.", - "device_id" if device_id else "long_lived_token", - identifier, - ) - return response, None - except grpc.RpcError as e: - return None, e - except Exception as e: - raise e - - -def decrypt_payload(device_id, payload_ciphertext): + stub = kwargs["stub"] + device_id = kwargs.get("device_id") + long_lived_token = kwargs.get("long_lived_token") + + request = vault_pb2.GetEntityAccessTokenRequest( + device_id=device_id, + long_lived_token=long_lived_token, + platform=platform, + account_identifier=account_identifier, + ) + + identifier = device_id or long_lived_token + logger.debug( + "Requesting access tokens for %s '%s'...", + "device_id" if device_id else "long_lived_token", + identifier, + ) + response = stub.GetEntityAccessToken(request) + + logger.info( + "Successfully retrieved access token for %s '%s'.", + "device_id" if device_id else "long_lived_token", + identifier, + ) + return response, None + + +@grpc_call() +def decrypt_payload(device_id, payload_ciphertext, **kwargs): """ Decrypts the payload. @@ -179,29 +186,22 @@ def decrypt_payload(device_id, payload_ciphertext): - server response (object): The vault server response. - error (Exception): The error encountered if the request fails, otherwise None. """ - try: - channel = get_channel() - - with channel as conn: - stub = vault_pb2_grpc.EntityInternalStub(conn) - request = vault_pb2.DecryptPayloadRequest( - device_id=device_id, payload_ciphertext=payload_ciphertext - ) - - logger.debug( - "Sending request to decrypt payload for device_id: %s", - device_id, - ) - response = stub.DecryptPayload(request) - logger.info("Successfully decrypted payload.") - return response, None - except grpc.RpcError as e: - return None, e - except Exception as e: - raise e - - -def encrypt_payload(device_id, payload_plaintext): + stub = kwargs["stub"] + request = vault_pb2.DecryptPayloadRequest( + device_id=device_id, payload_ciphertext=payload_ciphertext + ) + + logger.debug( + "Sending request to decrypt payload for device_id: %s", + device_id, + ) + response = stub.DecryptPayload(request) + logger.info("Successfully decrypted payload.") + return response, None + + +@grpc_call() +def encrypt_payload(device_id, payload_plaintext, **kwargs): """ Encrypts the payload. @@ -214,29 +214,22 @@ def encrypt_payload(device_id, payload_plaintext): - server response (object): The vault server response. - error (Exception): The error encountered if the request fails, otherwise None. """ - try: - channel = get_channel() - - with channel as conn: - stub = vault_pb2_grpc.EntityInternalStub(conn) - request = vault_pb2.EncryptPayloadRequest( - device_id=device_id, payload_plaintext=payload_plaintext - ) - - logger.debug( - "Sending request to encrypt payload for device_id: %s", - device_id, - ) - response = stub.EncryptPayload(request) - logger.info("Successfully encrypted payload.") - return response, None - except grpc.RpcError as e: - return None, e - except Exception as e: - raise e - - -def update_entity_token(device_id, token, platform, account_identifier): + stub = kwargs["stub"] + request = vault_pb2.EncryptPayloadRequest( + device_id=device_id, payload_plaintext=payload_plaintext + ) + + logger.debug( + "Sending request to encrypt payload for device_id: %s", + device_id, + ) + response = stub.EncryptPayload(request) + logger.info("Successfully encrypted payload.") + return response, None + + +@grpc_call() +def update_entity_token(device_id, token, platform, account_identifier, **kwargs): """Update an entity's token in the vault. Args: @@ -250,29 +243,22 @@ def update_entity_token(device_id, token, platform, account_identifier): - server response (object): The vault server response. - error (Exception): The error encountered if the request fails, otherwise None. """ - try: - channel = get_channel() - - with channel as conn: - stub = vault_pb2_grpc.EntityInternalStub(conn) - request = vault_pb2.UpdateEntityTokenRequest( - device_id=device_id, - token=token, - platform=platform, - account_identifier=account_identifier, - ) - - logger.debug("Updating token for platform '%s'", platform) - response = stub.UpdateEntityToken(request) - logger.info("Successfully updated token for platform '%s'", platform) - return response, None - except grpc.RpcError as e: - return None, e - except Exception as e: - raise e - - -def delete_entity_token(long_lived_token, platform, account_identifier): + stub = kwargs["stub"] + request = vault_pb2.UpdateEntityTokenRequest( + device_id=device_id, + token=token, + platform=platform, + account_identifier=account_identifier, + ) + + logger.debug("Updating token for platform '%s'", platform) + response = stub.UpdateEntityToken(request) + logger.info("Successfully updated token for platform '%s'", platform) + return response, None + + +@grpc_call() +def delete_entity_token(long_lived_token, platform, account_identifier, **kwargs): """Delete an entity's token in the vault. Args: @@ -285,22 +271,14 @@ def delete_entity_token(long_lived_token, platform, account_identifier): - server response (object): The vault server response. - error (Exception): The error encountered if the request fails, otherwise None. """ - try: - channel = get_channel() - - with channel as conn: - stub = vault_pb2_grpc.EntityInternalStub(conn) - request = vault_pb2.DeleteEntityTokenRequest( - long_lived_token=long_lived_token, - platform=platform, - account_identifier=account_identifier, - ) - - logger.debug("Deleting token for platform '%s'", platform) - response = stub.DeleteEntityToken(request) - logger.info("Successfully deleted token for platform '%s'", platform) - return response, None - except grpc.RpcError as e: - return None, e - except Exception as e: - raise e + stub = kwargs["stub"] + request = vault_pb2.DeleteEntityTokenRequest( + long_lived_token=long_lived_token, + platform=platform, + account_identifier=account_identifier, + ) + + logger.debug("Deleting token for platform '%s'", platform) + response = stub.DeleteEntityToken(request) + logger.info("Successfully deleted token for platform '%s'", platform) + return response, None