diff --git a/audio_transcription/fast_transcriber.py b/audio_transcription/fast_transcriber.py index 187dbb7..6249127 100644 --- a/audio_transcription/fast_transcriber.py +++ b/audio_transcription/fast_transcriber.py @@ -33,7 +33,8 @@ def audio_split_by_silence( min_speakers: int = -1, max_speakers: int = -1, min_silence_length: float = 0.8, - min_segment_length: float = -1 + min_segment_length: float = -1, + chunks: str = "", ): ''' :param file: Audio file @@ -47,6 +48,7 @@ def audio_split_by_silence( :param max_speakers: Maximum number of speakers to detect for diarization. Defaults to auto-detect when set to -1. :param min_silence_length: Minimum length of silence in seconds to use for splitting audio for parallel processing. Defaults to 0.8. :param min_segment_length: Minimum length of audio segment in seconds to use for splitting audio for parallel processing. If set to -1, we pick a value based on your settings. + :param chunks: A parameter to manually specify the start and end times of each chunk when splitting audio for parallel processing. If set to "", we use silence detection to split the audio. If set to a string formatted with a start and end second on each line, we use the specified chunks. Example: '0,10' and '10,20' on separate lines. ''' import os import sys @@ -151,11 +153,23 @@ def process_segment(segment): print(f"Took {time.time() - t:.2f} seconds to push segment from {start_time:.2f} to {end_time:.2f}") return whisper_job - segments = split_silences( - audio_path, - min_silence_length=min_silence_length, - min_segment_length=min_segment_length, - ) + if chunks == "": + segments = split_silences( + audio_path, + min_silence_length=min_silence_length, + min_segment_length=min_segment_length, + ) + else: + try: + # chunks is a string formatted with a start and end second on each line + segments = [ + tuple(map(float, line.split(","))) + for line in chunks.strip().split("\n") + ] + except: + raise ValueError( + "Invalid chunks format. Please provide a string formatted with a start and end second on each line. Example: '0,10\n10,20\n20,30'" + ) if not segments: segments.append(whisper.push(sieve.File(path=file.path), language=source_language, word_level_timestamps=word_level_timestamps, speed_boost=speed_boost, decode_boost=decode_boost))