Skip to content

Commit

Permalink
Добавил фронтенд и оптимизировал работу декодера и пунктуатора на бэк…
Browse files Browse the repository at this point in the history
…енде
  • Loading branch information
sxdxfan committed Oct 23, 2021
1 parent aeef248 commit 709bd86
Show file tree
Hide file tree
Showing 145 changed files with 5,989 additions and 59 deletions.
10 changes: 7 additions & 3 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,23 @@ def index():

@app.route('/asr', methods=['POST'])
def asr():
host_url = "https://asr-contest.nanosemantics.ai"
res = []
for f in request.files:
if f.startswith('audio_blob') and FileHandler.check_format(request.files[f]):

response_code, filename, response = FileHandler.get_recognized_text(request.files[f])
response_code, audio_file, docx_file, response = FileHandler.get_recognized_text(request.files[f])

if response_code == 0:
response_audio_url = url_for('media_file', filename=filename)
response_audio_url = url_for('media_file', filename=audio_file)
response_docx_url = url_for('media_file', filename=docx_file)
else:
response_audio_url = None
response_docx_url = None

res.append({
'response_audio_url': response_audio_url,
'response_docx_url': host_url + response_docx_url if response_docx_url else '',
'response_audio_url': host_url + response_audio_url if response_audio_url else '',
'response_code': response_code,
'response': response,
})
Expand Down
15 changes: 12 additions & 3 deletions config.ini
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ labels = [_-абвгдеёжзийклмнопрстуфхцчшщъыьэюя ]
model_path = data/w2l-16khz.hdf

# Path to language model
lm_path = data/vosk/lm.klm
lm_path = data/lm/lm.klm

# Path to the lexicon file
lexicon = data/vosk/lexicon.txt
lexicon = data/lm/lexicon.txt

# Path to prediction tokens file
tokens = data/tokens.txt
Expand All @@ -32,6 +32,15 @@ window_size = 0.02
# Window stride in seconds for acoustic model samples
window_stride = 0.01

# Voice Activity Detector agressiveness mode (0-3)
vad_aggressiveness_mode = 3

# Voice Activity Detector frame duration in milliseconds
vad_frame_duration_ms = 10

# Voice Activity Detector maximum pause duration in milliseconds
vad_max_pause_ms = 500


[Train]
# Path to train manifest csv
Expand Down Expand Up @@ -62,4 +71,4 @@ checkpoint_per_batch = 1000
save_folder = Checkpoints/

# Continue from checkpoint model
continue_from = data/w2l-16khz.hdf
continue_from = data/w2l-16khz.hdf
5 changes: 2 additions & 3 deletions data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@ def load_audio(path, sample_rate):
sound = sound.set_channels(1)
sound = sound.set_sample_width(2)

return np.array(sound.get_array_of_samples()).astype(float)
return sound


def preprocess(audio_path, sample_rate=16000, window_size=0.02, window_stride=0.01, window='hamming'):
audio = load_audio(audio_path, sample_rate)
def preprocess(audio, sample_rate=16000, window_size=0.02, window_stride=0.01, window='hamming'):
nfft = int(sample_rate * window_size)
win_length = nfft
hop_length = int(sample_rate * window_stride)
Expand Down
21 changes: 13 additions & 8 deletions decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def decode(self, output, start_timestamp=0, frame_time=0.02):


class TrieDecoder:
def __init__(self, lexicon, tokens, lm_path, beam_threshold=30):
def __init__(self, lexicon, tokens, lm_path, beam_threshold=10):
from trie_decoder.common import Dictionary, create_word_dict, load_words
from trie_decoder.decoder import CriterionType, DecoderOptions, KenLM, LexiconDecoder
lexicon = load_words(lexicon)
Expand Down Expand Up @@ -101,12 +101,16 @@ def get_trie(self, lexicon):

return trie, sil_idx, blank_idx, unk_idx

def decode(self, output, start_timestamp=0, frame_time=0.02):
def decode(self, output, start_timestamp=0, frame_time=0.02, max_decoder_len=500):
output = np.log(softmax(output[:, :].astype(np.float32, copy=False), axis=-1))

t, n = output.shape
result = self.trieDecoder.decode(output.ctypes.data, t, n)[0]
tokens = result.tokens
results = []
for i in range(1 + output.shape[0] // max_decoder_len):
output_part = output[i * max_decoder_len:(i + 1) * max_decoder_len]
t, n = output_part.shape
results.append(self.trieDecoder.decode(output_part.ctypes.data, t, n)[0])

tokens = [token for result in results for token in result.tokens]

words, new_word = [], True
current_word, current_timestamp, start_idx, end_idx = None, start_timestamp, 0, 0
Expand Down Expand Up @@ -134,14 +138,15 @@ def decode(self, output, start_timestamp=0, frame_time=0.02):
words_len += end_idx - start_idx
words.append({
"word": current_word,
"start": np.round(current_timestamp, 2),
"timestamp": max(0.0, np.round(current_timestamp, 2) - 0.2),
"end": np.round(end_timestamp, 2),
"confidence": np.round(np.exp(word_lm_score / max(1, end_idx - start_idx)) * 100, 2)
"confidence": np.round(np.exp(word_lm_score / max(10, end_idx - start_idx)) * 100, 2)
})

else:
current_word += self.tokenDict.get_entry(k)

score = np.round(np.exp(result.score / max(1, words_len)), 2)
score = np.mean([result.score for result in results])
score = np.round(np.exp(score / max(1, words_len)), 2)

return DecodeResult(score, words)
31 changes: 31 additions & 0 deletions decoder_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import configparser
from decoder import TrieDecoder
from flask import Flask, request
import json
import numpy as np


config = configparser.ConfigParser()
config.read("config.ini", encoding="UTF-8")
lexicon = config["Wav2Letter"]["lexicon"]
tokens = config["Wav2Letter"]["tokens"]
lm_path = config["Wav2Letter"]["lm_path"]
beam_threshold = float(config["Wav2Letter"]["beam_threshold"])
decoder = TrieDecoder(lexicon, tokens, lm_path, beam_threshold)

app = Flask(__name__)


@app.route("/decode", methods=["POST"])
def decode():
data = request.json
outputs = np.array(data["outputs"])
result = decoder.decode(outputs, start_timestamp=data["start_timestamp"])

results = {
"text": result.text,
"score": result.score,
"words": result.words
}

return json.dumps(results, ensure_ascii=False)
32 changes: 27 additions & 5 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@ services:
runtime: nvidia
environment:
- NVIDIA_VISIBLE_DEVICES=all

image: sova-asr:master
volumes:
- .:/sova-asr
command: bash -c "gunicorn --access-logfile - -w 1 --bind 0.0.0.0:8888 app:app --timeout 15000"
ports:
- 8888:8888
command: bash -c "gunicorn --access-logfile - -w 2 --bind 0.0.0.0:8888 app:app --timeout 15000"
network_mode: host

sova-asr-train:
restart: "no"
Expand All @@ -30,4 +28,28 @@ services:
image: sova-asr:master
volumes:
- .:/sova-asr
command: bash -c "python3 train.py"
command: bash -c "python3 train.py"

sova-asr-decoder:
restart: always
container_name: sova-asr-decoder
build:
context: .
dockerfile: Dockerfile
image: sova-asr:master
volumes:
- .:/sova-asr
command: bash -c "gunicorn --access-logfile - -w 4 --bind 0.0.0.0:8889 decoder_app:app --timeout 15000"
network_mode: host

sova-asr-punctuator:
restart: always
container_name: sova-asr-punctuator
build:
context: .
dockerfile: Dockerfile
image: sova-asr:master
volumes:
- .:/sova-asr
command: bash -c "gunicorn --access-logfile - -w 2 --bind 0.0.0.0:8890 punctuator_app:app --timeout 15000"
network_mode: host
116 changes: 101 additions & 15 deletions file_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,103 @@
import time
import logging
import uuid
import json
from speech_recognizer import SpeechRecognizer
from punctuator import Punctuator
from number_utils.text2numbers import TextToNumbers
from docx import Document
from docx.enum.text import WD_COLOR_INDEX
import numpy as np
from multiprocessing.pool import ThreadPool
from requests import request
from datetime import datetime


speech_recognizer = SpeechRecognizer()
punctuator = Punctuator(model_path="data/punctuator")
text2numbers = TextToNumbers()


def write_highlighted_text(text, words, document, threshold=70):
punctuation = ".,?"
numbers = "0123456789"
tokens = text.split(" ")
words_position = 0

paragraph = document.add_paragraph()

for token in tokens:
if len(token) == 0:
continue

if token[-1] in punctuation:
punct = token[-1]
token = token[:-1]
else:
token = token
punct = ""

if token[0] in numbers:
paragraph.add_run(token)
else:
while words_position < len(words) and words[words_position]["word"] != token.lower():
words_position += 1

confidence = words[words_position]["confidence"]

if confidence < threshold:
font = paragraph.add_run(token).font
font.highlight_color = WD_COLOR_INDEX.YELLOW
else:
paragraph.add_run(token)

paragraph.add_run(punct)
if punct == ".":
paragraph.add_run("\n")
else:
paragraph.add_run(" ")


def punctuator_request(text):
url = "http://localhost:8890/predict"
data = {
"text": text
}

response = request("POST", url, json=data)

result = json.loads(response.text)

return result


class FileHandler:
@staticmethod
def get_recognized_text(blob):
try:
filename = str(uuid.uuid4())
# filename = str(uuid.uuid4())
ts = time.time()
filename = (os.path.splitext(blob.filename)[0]).replace(" ", "_") + str(datetime.utcfromtimestamp(ts).strftime('_%Y-%m-%d_%H:%M:%S'))
os.makedirs('./records', exist_ok=True)
new_record_path = os.path.join('./records', filename + '.webm')
blob.save(new_record_path)
new_filename = filename + '.wav'
converted_record_path = FileHandler.convert_to_wav(new_record_path, new_filename)
audio_file = filename + '.wav'
converted_record_path = FileHandler.convert_to_wav(new_record_path, audio_file)
response_models_result = FileHandler.get_models_result(converted_record_path)
return 0, new_filename, response_models_result

document = Document()
document.add_heading('Протокол конференции', level=1)
for result in response_models_result:
text = result.get('text')
words = result.get('words')
write_highlighted_text(text, words, document)

docx_file = filename + '.docx'
document.save(f'./records/{docx_file}')

return 0, audio_file, docx_file, response_models_result

except Exception as e:
logging.exception(e)
return 1, None, str(e)
return 1, None, None, str(e)

@staticmethod
def convert_to_wav(webm_full_filepath, new_filename):
Expand All @@ -52,19 +124,33 @@ def check_format(files):
return True

@staticmethod
def get_models_result(converted_record_path, delimiter='<br>'):
results = []
def get_models_result(converted_record_path):
start = time.time()
decoder_result = speech_recognizer.recognize(converted_record_path)
text = punctuator.predict(decoder_result.text)
text = text2numbers.convert(text)

results = []
decoder_results = speech_recognizer.recognize(converted_record_path)

score = np.mean([result.get("score") for result in decoder_results])
words = [w for result in decoder_results for w in result.get("words")]

texts = [result.get("text") for result in decoder_results]

texts = [text2numbers.convert(text) for text in [
" ".join(texts[:len(texts) // 2]),
" ".join(texts[len(texts) // 2:])
]]

pool = ThreadPool(processes=2)
texts = pool.map(punctuator_request, texts)
text = " ".join(texts)

end = time.time()
results.append(
{
'text': text,
'text': text.strip(),
'time': round(end - start, 3),
'confidence': decoder_result.score,
'words': decoder_result.words
'confidence': score,
'words': words
}
)
return results
Loading

0 comments on commit 709bd86

Please sign in to comment.