
Various fixes and features:
* bumped whisperx to the latest version to fix bugs and improve alignment
* added `suppress_numerals` to improve diarization accuracy
* improved `faster-whisper` transcription arguments (see the sketch below)
MahmoudAshraf97 committed Oct 6, 2023
1 parent 471b383 commit 86dea57
Showing 7 changed files with 217 additions and 52 deletions.
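
The heart of the transcription change, shared by the notebook and `diarize.py` below (a sketch of the updated call; `whisper_model`, `vocal_target`, and `numeral_symbol_tokens` are defined earlier in each file):

```python
segments, info = whisper_model.transcribe(
    vocal_target,                           # path to the (optionally stemmed) audio
    beam_size=5,                            # raised from 1 for more accurate decoding
    word_timestamps=True,                   # word-level timings needed for diarization
    suppress_tokens=numeral_symbol_tokens,  # None unless numeral suppression is enabled
    vad_filter=True,                        # skip non-speech segments via built-in VAD
)
```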
2 changes: 1 addition & 1 deletion README.md
@@ -75,9 +75,9 @@ If your system has enough VRAM (>=10GB), you can use `diarize_parallel.py` instead
- `-a AUDIO_FILE_NAME`: The name of the audio file to be processed
- `--no-stem`: Disables source separation
- `--whisper-model`: The model to be used for ASR, default is `medium.en`
- `--suppress_numerals`: Transcribes numbers as spoken words instead of digits, which improves alignment accuracy
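  For example: `python diarize.py -a AUDIO_FILE_NAME --suppress_numerals`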

## Known Limitations
- Only tested on English, but several other languages are supported
- Overlapping speakers are not yet handled; a possible approach would be to separate the audio and isolate a single speaker before feeding it into the pipeline, but this would require much more computation
- There might be some errors; please raise an issue if you encounter any.

127 changes: 97 additions & 30 deletions Whisper_Transcription_+_NeMo_Diarization.ipynb
@@ -29,7 +29,7 @@
},
"outputs": [],
"source": [
"!pip install git+https://github.com/m-bain/whisperX.git@4cb167a225c0ebaea127fd6049abfaa3af9f8bb4\n",
"!pip install git+https://github.com/m-bain/whisperX.git@051047bb25b740fed2ea93ca737499c37e8dc9d4\n",
"!pip install --no-build-isolation nemo_toolkit[asr]==1.19.1 \n",
"!pip install faster-whisper==0.7.1\n",
"!pip install git+https://github.com/facebookresearch/demucs#egg=demucs\n",
@@ -55,8 +55,7 @@
"from faster_whisper import WhisperModel\n",
"import whisperx\n",
"import torch\n",
"import librosa\n",
"import soundfile\n",
"from pydub import AudioSegment\n",
"from nemo.collections.asr.models.msdd_models import NeuralDiarizer\n",
"from deepmultilingualpunctuation import PunctuationModel\n",
"import re\n",
@@ -106,7 +105,6 @@
" \"ja\",\n",
" \"zh\",\n",
" \"uk\",\n",
" \"pt\",\n",
" \"ar\",\n",
" \"ru\",\n",
" \"pl\",\n",
@@ -115,6 +113,14 @@
" \"fa\",\n",
" \"el\",\n",
" \"tr\",\n",
" \"cs\",\n",
" \"da\",\n",
" \"he\",\n",
" \"vi\",\n",
" \"ko\",\n",
" \"ur\",\n",
" \"te\",\n",
" \"hi\",\n",
"]\n",
"\n",
"\n",
@@ -188,7 +194,7 @@
" ws, we, wrd = (\n",
" int(wrd_dict[\"start\"] * 1000),\n",
" int(wrd_dict[\"end\"] * 1000),\n",
" wrd_dict[\"text\"],\n",
" wrd_dict[\"word\"],\n",
" )\n",
" wrd_pos = get_word_ts_anchor(ws, we, word_anchor_option)\n",
" while wrd_pos > float(e):\n",
@@ -372,6 +378,60 @@
" )\n",
"\n",
"\n",
"def find_numeral_symbol_tokens(tokenizer):\n",
" numeral_symbol_tokens = [\n",
" -1,\n",
" ]\n",
" for token, token_id in tokenizer.get_vocab().items():\n",
" has_numeral_symbol = any(c in \"0123456789%$£\" for c in token)\n",
" if has_numeral_symbol:\n",
" numeral_symbol_tokens.append(token_id)\n",
" return numeral_symbol_tokens\n",
"\n",
"\n",
"def _get_next_start_timestamp(word_timestamps, current_word_index):\n",
" # if current word is the last word\n",
" if current_word_index == len(word_timestamps) - 1:\n",
" return word_timestamps[current_word_index][\"start\"]\n",
"\n",
" next_word_index = current_word_index + 1\n",
" while current_word_index < len(word_timestamps) - 1:\n",
" if word_timestamps[next_word_index].get(\"start\") is None:\n",
" # if next word doesn't have a start timestamp\n",
" # merge it with the current word and delete it\n",
" word_timestamps[current_word_index][\"word\"] += (\n",
" \" \" + word_timestamps[next_word_index][\"word\"]\n",
" )\n",
"\n",
" word_timestamps[next_word_index][\"word\"] = None\n",
" next_word_index += 1\n",
"\n",
" else:\n",
" return word_timestamps[next_word_index][\"start\"]\n",
"\n",
"\n",
"def filter_missing_timestamps(word_timestamps):\n",
" # handle the first and last word\n",
" if word_timestamps[0].get(\"start\") is None:\n",
" word_timestamps[0][\"start\"] = 0\n",
" word_timestamps[0][\"end\"] = _get_next_start_timestamp(word_timestamps, 0)\n",
"\n",
" result = [\n",
" word_timestamps[0],\n",
" ]\n",
"\n",
" for i, ws in enumerate(word_timestamps[1:], start=1):\n",
" # if ws doesn't have a start and end\n",
" # use the previous end as start and next start as end\n",
" if ws.get(\"start\") is None:\n",
" ws[\"start\"] = word_timestamps[i - 1][\"end\"]\n",
" ws[\"end\"] = _get_next_start_timestamp(word_timestamps, i)\n",
"\n",
" if ws[\"word\"] is not None:\n",
" result.append(ws)\n",
" return result\n",
"\n",
"\n",
"def cleanup(path: str):\n",
" \"\"\"path could either be relative or absolute.\"\"\"\n",
" # check if file or directory exists\n",
@@ -404,13 +464,16 @@
"outputs": [],
"source": [
"# Name of the audio file\n",
"audio_path = '9d54248d-60f5-4661-97d1-88b23568b2db.mp3'\n",
"audio_path = \"20200128-Pieter Wuille (part 1 of 2) - Episode 1.mp3\"\n",
"\n",
"# Whether to enable music removal from speech, helps increase diarization quality but uses alot of ram\n",
"enable_stemming = True\n",
"\n",
"# (choose from 'tiny.en', 'tiny', 'base.en', 'base', 'small.en', 'small', 'medium.en', 'medium', 'large-v1', 'large-v2', 'large')\n",
"whisper_model_name = 'medium.en'"
"whisper_model_name = \"large-v2\"\n",
"\n",
"# replaces numerical digits with their pronounciation, increases diarization accuracy\n",
"suppress_numerals = True"
]
},
{
@@ -447,15 +510,7 @@
"id": "HKcgQUrAzsJZ",
"outputId": "dc2a1d96-20da-4749-9d64-21edacfba1b1"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Source splitting failed, using original audio file. Use --no-stem argument to disable it.\n"
]
}
],
"outputs": [],
"source": [
"if enable_stemming:\n",
" # Isolate vocals from the rest of the audio\n",
@@ -465,9 +520,7 @@
" )\n",
"\n",
" if return_code != 0:\n",
" logging.warning(\n",
" \"Source splitting failed, using original audio file.\"\n",
" )\n",
" logging.warning(\"Source splitting failed, using original audio file.\")\n",
" vocal_target = audio_path\n",
" else:\n",
" vocal_target = os.path.join(\n",
@@ -508,9 +561,17 @@
"# model = WhisperModel(model_size, device=\"cuda\", compute_type=\"int8_float16\")\n",
"# or run on CPU with INT8\n",
"# model = WhisperModel(model_size, device=\"cpu\", compute_type=\"int8\")\n",
"if suppress_numerals:\n",
" numeral_symbol_tokens = find_numeral_symbol_tokens(whisper_model.hf_tokenizer)\n",
"else:\n",
" numeral_symbol_tokens = None\n",
"\n",
"segments, info = whisper_model.transcribe(\n",
" vocal_target, beam_size=1, word_timestamps=True\n",
" vocal_target,\n",
" beam_size=5,\n",
" word_timestamps=True,\n",
" suppress_tokens=numeral_symbol_tokens,\n",
" vad_filter=True,\n",
")\n",
"whisper_results = []\n",
"for segment in segments:\n",
@@ -550,15 +611,16 @@
" result_aligned = whisperx.align(\n",
" whisper_results, alignment_model, metadata, vocal_target, device\n",
" )\n",
" word_timestamps = result_aligned[\"word_segments\"]\n",
" word_timestamps = filter_missing_timestamps(result_aligned[\"word_segments\"])\n",
"\n",
" # clear gpu vram\n",
" del alignment_model\n",
" torch.cuda.empty_cache()\n",
"else:\n",
" word_timestamps = []\n",
" for segment in whisper_results:\n",
" for word in segment[\"words\"]:\n",
" word_timestamps.append({\"text\": word[2], \"start\": word[0], \"end\": word[1]})"
" word_timestamps.append({\"word\": word[2], \"start\": word[0], \"end\": word[1]})"
]
},
{
@@ -574,16 +636,14 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rVPuY1VS0xN7"
},
"metadata": {},
"outputs": [],
"source": [
"signal, sample_rate = librosa.load(vocal_target, sr=None)\n",
"sound = AudioSegment.from_file(vocal_target).set_channels(1)\n",
"ROOT = os.getcwd()\n",
"temp_path = os.path.join(ROOT, \"temp_outputs\")\n",
"os.makedirs(temp_path, exist_ok=True)\n",
"soundfile.write(os.path.join(temp_path, \"mono_file.wav\"), signal, sample_rate, \"PCM_24\")"
"sound.export(os.path.join(temp_path, \"mono_file.wav\"), format=\"wav\")"
]
},
{
@@ -702,8 +762,6 @@
" word = word.rstrip(\".\")\n",
" word_dict[\"word\"] = word\n",
"\n",
" \n",
"\n",
" wsm = get_realigned_ws_mapping_with_punctuation(wsm)\n",
"else:\n",
" print(\n",
@@ -758,7 +816,16 @@
"name": "python3"
},
"language_info": {
"name": "python"
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
31 changes: 25 additions & 6 deletions diarize.py
@@ -10,7 +10,7 @@
import re
import logging

mtypes = {'cpu': 'int8', 'cuda': 'float16'}
mtypes = {"cpu": "int8", "cuda": "float16"}

# Initialize parser
parser = argparse.ArgumentParser()
@@ -26,6 +26,15 @@
"This helps with long files that don't contain a lot of music.",
)

parser.add_argument(
"--suppress_numerals",
action="store_true",
dest="suppress_numerals",
default=False,
help="Suppresses Numerical Digits."
"This helps the diarization accuracy but converts all digits into written text.",
)

parser.add_argument(
"--whisper-model",
dest="model_name",
@@ -64,15 +73,25 @@

# Run on GPU with FP16
whisper_model = WhisperModel(
args.model_name, device=args.device, compute_type=mtypes[args.device])
args.model_name, device=args.device, compute_type=mtypes[args.device]
)

# or run on GPU with INT8
# model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
# or run on CPU with INT8
# model = WhisperModel(model_size, device="cpu", compute_type="int8")

if args.suppress_numerals:
numeral_symbol_tokens = find_numeral_symbol_tokens(whisper_model.hf_tokenizer)
else:
numeral_symbol_tokens = None

segments, info = whisper_model.transcribe(
vocal_target, beam_size=1, word_timestamps=True
vocal_target,
beam_size=5,
word_timestamps=True,
suppress_tokens=numeral_symbol_tokens,
vad_filter=True,
)
whisper_results = []
for segment in segments:
@@ -88,15 +107,15 @@
result_aligned = whisperx.align(
whisper_results, alignment_model, metadata, vocal_target, args.device
)
word_timestamps = result_aligned["word_segments"]
word_timestamps = filter_missing_timestamps(result_aligned["word_segments"])
# clear gpu vram
del alignment_model
torch.cuda.empty_cache()
else:
word_timestamps = []
for segment in whisper_results:
for word in segment["words"]:
word_timestamps.append({"text": word[2], "start": word[0], "end": word[1]})
word_timestamps.append({"word": word[2], "start": word[0], "end": word[1]})


# convert audio to mono for NeMo compatibility
@@ -156,7 +175,7 @@
wsm = get_realigned_ws_mapping_with_punctuation(wsm)
else:
logging.warning(
f'Punctuation restoration is not available for {info.language} language.'
f"Punctuation restoration is not available for {info.language} language."
)

ssm = get_sentences_speaker_mapping(wsm, speaker_ts)
(4 more changed files not shown)
