Merge pull request #7 from bcc-code/feat/upgrade-to-whisper-3
Upgrade to Whisper 3
KillerX authored Aug 30, 2024
2 parents d48fe32 + 02cd2e6 commit 7a02f6d
Showing 6 changed files with 130 additions and 43 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -1,4 +1,6 @@
tmp
whisper-api
env
build/
build/
.DS_Store
.idea
4 changes: 2 additions & 2 deletions Makefile
@@ -2,14 +2,14 @@ run:
go run .

build:
rm -r ./build
rm -fr ./build
go build -o build/bin .
mkdir ./build/bcc-whisper
cp ./bcc-whisper/*.py ./build/bcc-whisper/
cp ./bcc-whisper/*.txt ./build/bcc-whisper/

build-linux-amd64:
rm -r ./build
rm -fr ./build
GOOS=linux GOARCH=amd64 go build -o build/bin .
mkdir ./build/bcc-whisper
cp ./bcc-whisper/*.py ./build/bcc-whisper/
138 changes: 102 additions & 36 deletions bcc-whisper/main.py
@@ -1,11 +1,15 @@
import argparse
import copy
import json
import math
import ssl
import torch
import whisper_timestamped as whisper
from datasets import Dataset, Audio
#import whisper_timestamped as whisper
import whisper
import os
import datetime
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

from inference import load_model, inference, AudioClassifier

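For context, the new transformers imports replace whisper_timestamped with a Hugging Face ASR pipeline. Below is a minimal, self-contained sketch of that usage (not part of the diff; the file path is a placeholder, ffmpeg is assumed to be available for decoding, and the model id matches the new default introduced later in this PR):

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

model_id = "openai/whisper-large-v3"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load the seq2seq speech model and its processor, then build an ASR pipeline.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)

# Word-level timestamps and per-chunk language, as used by transcribe_file in this diff.
result = pipe("/path/to/audio.wav", return_timestamps="word", return_language=True)
print(result["text"])
print(result["chunks"][0])  # e.g. {"text": ..., "timestamp": (start, end), "language": ...}
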
@@ -20,7 +24,7 @@
def parse_arguments():
parser = argparse.ArgumentParser(description="", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-l", "--language", help="language", default="no")
parser.add_argument("-m", "--model", help="whisper model to use", default="medium")
parser.add_argument("-m", "--model", help="whisper model to use", default="openai/whisper-large-v3")
parser.add_argument("src", help="source file")
parser.add_argument("output", help="output file")

@@ -30,84 +34,146 @@ def parse_arguments():
def main():
# import for side effects
_ = AudioClassifier
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

config = vars(parse_arguments())
file = config["src"]
out = config["output"]
language = config["language"]

model = config["model"]
transcribe_file(device, file, out, language, model)

def transcribe_file(device: torch.device, file: str, out: str, language: str, model: str):
transcribe_file(file, out, language, model)

def transcribe_file(file: str, out: str, language: str, model_id: str):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

detection_model = load_model(device)
detection_model.eval()
res = inference(detection_model, file, device, SAMPLE_RATE, SAMPLES_PER_CHUNK, LENGTH)
classified_segments = inference(detection_model, file, device, SAMPLE_RATE, SAMPLES_PER_CHUNK, LENGTH)

current_type = "song"

# start and end are counted in chunks/segments, so the timestamp is start * LENGTH and end * LENGTH
start = 0
end = 0

res2: list[tuple[str, int, int]] = []
speech_segments: list[tuple[str, int, int]] = []

for x in res:
# Filter out non-speech and join continuous speech segments
for x in classified_segments:
d = x[1]
# The sensitivity can be adjusted here a bit
if d <= 2 or current_type == x[0]:
if d <= 3 or current_type == x[0]:
end += d
else:
print(x)
if current_type == "speech":
res2.append((current_type, start * LENGTH, end * LENGTH))
speech_segments.append((current_type, start * LENGTH, end * LENGTH))
start = end
end = end + x[1]
current_type = x[0]

if current_type == "speech":
res2.append((current_type, start * LENGTH, end * LENGTH))

audio = whisper.load_audio(file, SAMPLE_RATE)
speech_segments.append((current_type, start * LENGTH, end * LENGTH))

print(speech_segments)

## This works, but we currently don't know what to do with the output, so it is disabled
#p = Pipeline.from_pretrained( "pyannote/speaker-diarization-3.1", use_auth_token="<REPLACE WITH TOKEN>")
#diarization = p(file)
#print(diarization)

# Load the audio file and fetch just the decoded audio data from it
audio = Dataset.from_dict( {"audio": [file]}).cast_column("audio", Audio())[0]["audio"]

# Load model and set up the pipeline
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
torch_dtype=torch_dtype,
device=device,
)

parts = {
"text": "",
"segments": [],
"language": language,
}

loaded_model = whisper.load_model(model)
# Transcribe each speech segment
for r in speech_segments:
from_index = math.floor(r[1] * audio["sampling_rate"])
to_index = math.floor(r[2] * audio["sampling_rate"])

# No kwargs means automatic language detection
kwargs = None
if language != "" and language != "auto":
kwargs = {"language": language}

current_language = language
for r in res2:
from_index = math.floor(r[1] * SAMPLE_RATE)
to_index = math.floor(r[2] * SAMPLE_RATE)
# Deep-copy the record; otherwise the ["array"] key no longer exists in the next iteration
a2 = copy.deepcopy(audio)

if language == "" or language == "auto":
# detect the spoken language
mel = (whisper.log_mel_spectrogram(whisper.pad_or_trim(audio[from_index:to_index]))
.to(loaded_model.device))
_, probs = loaded_model.detect_language(mel)
current_language = max(probs, key=probs.get)
print(f"Detected language: {current_language}")
# Cut the chunk we want to transcribe out of the audio
a2["array"] = audio["array"][from_index:to_index]

result = whisper.transcribe(loaded_model, audio=audio[from_index:to_index],
verbose=True,
language=current_language)
# Do the transcription
result = pipe(a2, return_timestamps="word", return_language=True, generate_kwargs=kwargs)


# Adjust timestamps and format of the data
if parts["text"] != "":
parts["text"] += "\n\n"

parts["text"] += result["text"]

for segment in result["segments"]:
segment["start"] += r[1]
segment["end"] += r[1]
for word in segment["words"]:
word["start"] += r[1]
word["end"] += r[1]
segment = None
word_count = 0
for rchunk in result["chunks"]:
if segment is None:
segment = {
"text": "",
"start": rchunk["timestamp"][0] + r[1],
"words": [],
}

rchunk["text"] = rchunk["text"].strip()
segment["end"] = rchunk["timestamp"][1] + r[1]
segment["text"] += " " + rchunk["text"]

segment["words"].append({
"text": rchunk["text"],
"start": rchunk["timestamp"][0] + r[1],
"end": rchunk["timestamp"][1] + r[1],
})

word_count+=1

# This controls what we consider one line in the srt file
if word_count > 11 or (word_count > 5 and rchunk['text'][-1] in "!?.:"):
parts["segments"].append(segment)
print(segment["text"])
segment = None
word_count = 0

# If no language was specified, store the language detected for the first chunk (when available)
if (parts["language"] == "" or parts["language"] == "auto") and result["chunks"][0]['language'] is not None:
parts["language"] = result["chunks"][0]['language']

if segment is not None:
parts["segments"].append(segment)

print(parts)

# Write the results to files in various formats
out_file = out.rstrip("/") + "/" + os.path.basename(file)
f = open(out_file + ".json", "w")
f.write(json.dumps(parts))
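To make the chunk-to-segment grouping in the hunk above concrete, here is a small, runnable sketch using made-up word chunks in the shape the pipeline returns (values are illustrative only; an offset of 10.0 s stands in for the start of a speech segment):

# Illustrative only: fake word-level chunks in the shape pipe(..., return_timestamps="word") produces.
chunks = [
    {"text": " Hello", "timestamp": (0.0, 0.4)},
    {"text": " world.", "timestamp": (0.5, 0.9)},
    {"text": " This", "timestamp": (1.2, 1.4)},
    {"text": " is", "timestamp": (1.4, 1.5)},
    {"text": " a", "timestamp": (1.5, 1.6)},
    {"text": " test!", "timestamp": (1.6, 2.0)},
]
offset = 10.0  # start of the speech segment within the full recording, in seconds

segments = []
segment = None
word_count = 0
for chunk in chunks:
    if segment is None:
        segment = {"text": "", "start": chunk["timestamp"][0] + offset, "words": []}
    text = chunk["text"].strip()
    segment["end"] = chunk["timestamp"][1] + offset
    segment["text"] += " " + text
    segment["words"].append({
        "text": text,
        "start": chunk["timestamp"][0] + offset,
        "end": chunk["timestamp"][1] + offset,
    })
    word_count += 1
    # Close a subtitle line after ~12 words, or after 6+ words that end a sentence.
    if word_count > 11 or (word_count > 5 and text[-1] in "!?.:"):
        segments.append(segment)
        segment = None
        word_count = 0

if segment is not None:
    segments.append(segment)

print(segments)
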
@@ -129,7 +195,6 @@ def transcribe_file(device: torch.device, file: str, out: str, language: str, mo
f.write(to_txt(parts["segments"]))
f.close()


def convert_seconds_to_vtt_timestamp(seconds):
# Convert seconds to a timedelta object
delta = datetime.timedelta(seconds=seconds)
@@ -213,3 +278,4 @@ def to_txt(segments: []):

if __name__ == "__main__":
main()

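convert_seconds_to_vtt_timestamp is only partially visible in this diff; a minimal sketch of such a helper, assuming the usual HH:MM:SS.mmm WebVTT format that the visible timedelta line suggests (illustrative only, not the repository's exact implementation):

import datetime

def seconds_to_vtt_timestamp(seconds: float) -> str:
    # Sketch only; the repository's convert_seconds_to_vtt_timestamp is not fully shown above.
    delta = datetime.timedelta(seconds=seconds)
    whole_seconds = int(delta.total_seconds())
    hours, remainder = divmod(whole_seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    millis = int((seconds - whole_seconds) * 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{millis:03d}"

print(seconds_to_vtt_timestamp(75.5))  # 00:01:15.500
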
6 changes: 4 additions & 2 deletions bcc-whisper/requirements.txt
@@ -1,5 +1,7 @@
git+https://github.com/bcc-code/bccmedia-song-or-not.git
git+https://github.com/openai/whisper.git
git+https://github.com/linto-ai/whisper-timestamped
whisper
torch
torchaudio
transformers
accelerate
datasets[audio]
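The new datasets[audio] dependency backs the Dataset/Audio loading path added in bcc-whisper/main.py; a minimal sketch of that usage, with a placeholder file path:

from datasets import Dataset, Audio

# Decode a local file into a dict with "array", "sampling_rate" and "path" keys,
# which is the shape main.py slices per speech segment.
audio = Dataset.from_dict({"audio": ["/path/to/recording.wav"]}).cast_column(
    "audio", Audio()
)[0]["audio"]

print(audio["sampling_rate"], len(audio["array"]))
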
2 changes: 1 addition & 1 deletion transcription-handler.go
@@ -66,7 +66,7 @@ func (h *handlers) SubmitJob(c *gin.Context) {
}

if job.Model == "" {
job.Model = "large-v3"
job.Model = "openai/whisper-large-v3"
}

if job.Priority >= 500 {
19 changes: 18 additions & 1 deletion transcription-runner.go
@@ -11,6 +11,7 @@ import (
"time"

"github.com/bcc-code/mediabank-bridge/log"
"github.com/samber/lo"
)

func doCallback(job *Job) {
@@ -47,7 +48,23 @@ func runJob(job *Job) {
return
}

cmd := exec.Command("python3", "bcc-whisper/main.py", "-l", job.Language, "-m", "large-v3", job.Path, job.OutputPath)
model := "openai/whisper-large-v3"
if lo.Contains([]string{
"openai/whisper-large-v2",
"openai/whisper-large-v3",
"openai/whisper-large",
"openai/whisper-medium",
"openai/whisper-small",
"openai/whisper-tiny",
"NbAiLab/nb-whisper-large",
"NbAiLab/nb-whisper-medium",
"NbAiLab/nb-whisper-small",
"NbAiLab/nb-whisper-tiny",
}, job.Model) {
model = job.Model
}

cmd := exec.Command("python3", "bcc-whisper/main.py", "-l", job.Language, "-m", model, job.Path, job.OutputPath)
cmd.Env = append(os.Environ(), "PYTHONUNBUFFERED=1")

stderr, _ := cmd.StderrPipe()
