From 986172dcf4d6b3953c32bfbed32d44a7d45cf93e Mon Sep 17 00:00:00 2001 From: agentmarketbot Date: Sun, 26 Jan 2025 15:51:45 +0000 Subject: [PATCH] Add OpenAI Whisper transcription service support Add alternative audio transcription service using OpenAI's Whisper API alongside existing AWS transcription. Key changes include: - Create new OpenAITranscriber class to handle Whisper API requests - Modify AudioTranscriber to support both AWS and OpenAI services - Add configuration options for transcription service selection - Add OPENAI_API_KEY and TRANSCRIPTION_SERVICE env variables - Make AWS services optional when using OpenAI transcription The system now defaults to AWS but can be switched to OpenAI's Whisper via the TRANSCRIPTION_SERVICE environment variable ('aws' or 'openai'). --- bot_handlers.py | 12 ++++++--- config.py | 2 ++ services.py | 70 ++++++++++++++++++++++++++++++++++++++----------- 3 files changed, 65 insertions(+), 19 deletions(-) diff --git a/bot_handlers.py b/bot_handlers.py index 74b6986..0149074 100644 --- a/bot_handlers.py +++ b/bot_handlers.py @@ -8,9 +8,15 @@ logger = logging.getLogger(__name__) # Initialize services -aws_services = AWSServices() -audio_transcriber = AudioTranscriber(aws_services) -text_summarizer = TextSummarizer(os.environ.get('MARKETROUTER_API_KEY')) +from config import Config + +aws_services = AWSServices() if Config.TRANSCRIPTION_SERVICE == 'aws' else None +audio_transcriber = AudioTranscriber( + aws_services=aws_services, + openai_api_key=Config.OPENAI_API_KEY, + service=Config.TRANSCRIPTION_SERVICE +) +text_summarizer = TextSummarizer(Config.MARKETROUTER_API_KEY) def handle_update(update: Dict[str, Any]) -> None: if 'message' in update: diff --git a/config.py b/config.py index 4b02d69..38526c3 100644 --- a/config.py +++ b/config.py @@ -6,3 +6,5 @@ class Config: MARKETROUTER_API_KEY = os.environ.get('MARKETROUTER_API_KEY') AWS_ACCESS_KEY_ID = os.environ.get('AWS_ACCESS_KEY_ID') AWS_SECRET_ACCESS_KEY = os.environ.get('AWS_SECRET_ACCESS_KEY') + OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY') + TRANSCRIPTION_SERVICE = os.environ.get('TRANSCRIPTION_SERVICE', 'aws') # 'aws' or 'openai' diff --git a/services.py b/services.py index 8d81328..ac1d33c 100644 --- a/services.py +++ b/services.py @@ -50,33 +50,71 @@ def start_transcription_job(self, job_name, media_uri, media_format='ogg', langu def get_transcription_job_status(self, job_name): return self.transcribe_client.get_transcription_job(TranscriptionJobName=job_name) +class OpenAITranscriber: + def __init__(self, api_key: str): + self.api_key = api_key + self.api_url = "https://api.openai.com/v1/audio/transcriptions" + + def transcribe_audio(self, audio_content: bytes) -> str: + try: + headers = { + "Authorization": f"Bearer {self.api_key}" + } + + files = { + 'file': ('audio.ogg', audio_content, 'audio/ogg'), + 'model': (None, 'whisper-1'), + } + + response = requests.post(self.api_url, headers=headers, files=files) + response.raise_for_status() + + return response.json()['text'] + except Exception as e: + logger.error(f"OpenAI transcription error: {e}") + raise + class AudioTranscriber: - def __init__(self, aws_services: AWSServices): + def __init__(self, aws_services: Optional[AWSServices] = None, openai_api_key: Optional[str] = None, service: str = 'aws'): + self.service = service self.aws_services = aws_services + self.openai_transcriber = OpenAITranscriber(openai_api_key) if openai_api_key else None self.bucket_name = 'audio-transcribe-temp' + if service == 'aws' and not aws_services: + raise ValueError("AWS services required for AWS transcription") + if service == 'openai' and not openai_api_key: + raise ValueError("OpenAI API key required for OpenAI transcription") + def transcribe_audio(self, file_url: str) -> str: try: - self.aws_services.create_s3_bucket_if_not_exists(self.bucket_name) - logger.info(f"S3 Bucket created/confirmed: {self.bucket_name}") - audio_content = self._download_audio(file_url) - object_key = f'audio_{uuid.uuid4()}.ogg' - s3_uri = self.aws_services.upload_file_to_s3(audio_content, self.bucket_name, object_key) - logger.info(f"S3 URI: {s3_uri}") - - job_name = f"whisper_job_{int(time.time())}" - self.aws_services.start_transcription_job(job_name, s3_uri) - logger.info(f"Transcription job started: {job_name}") - - transcription = self._wait_for_transcription(job_name) - self.aws_services.delete_file_from_s3(self.bucket_name, object_key) - - return transcription + + if self.service == 'openai': + return self.openai_transcriber.transcribe_audio(audio_content) + else: # aws + return self._transcribe_with_aws(audio_content) except Exception as e: logger.error(f"An error occurred: {e}") raise + def _transcribe_with_aws(self, audio_content: bytes) -> str: + self.aws_services.create_s3_bucket_if_not_exists(self.bucket_name) + logger.info(f"S3 Bucket created/confirmed: {self.bucket_name}") + + object_key = f'audio_{uuid.uuid4()}.ogg' + s3_uri = self.aws_services.upload_file_to_s3(audio_content, self.bucket_name, object_key) + logger.info(f"S3 URI: {s3_uri}") + + job_name = f"whisper_job_{int(time.time())}" + self.aws_services.start_transcription_job(job_name, s3_uri) + logger.info(f"Transcription job started: {job_name}") + + transcription = self._wait_for_transcription(job_name) + self.aws_services.delete_file_from_s3(self.bucket_name, object_key) + + return transcription + def _download_audio(self, file_url: str) -> bytes: response = requests.get(file_url) response.raise_for_status()