Add 'run_batch' mode for GPU encoding and decoding with batch_size >= 1 #1534

Open · wants to merge 75 commits into base: main

Commits (75)
33910f8
add concurrent performance testing for websocket_server_main (non-str…
veelion Jul 20, 2022
90c1d75
Merge branch 'main' of https://github.com/wenet-e2e/wenet
veelion Jul 20, 2022
4d52c27
Merge branch 'main' of https://github.com/wenet-e2e/wenet
veelion Jul 22, 2022
4f573b5
Merge branch 'main' of https://github.com/wenet-e2e/wenet
veelion Jul 28, 2022
fda37b7
Merge branch 'main' of https://github.com/wenet-e2e/wenet
veelion Aug 2, 2022
3f199f5
add batch processing to decoder
veelion Aug 5, 2022
455c870
add api_batch_main
veelion Aug 5, 2022
a67bc93
add BatchRecognizer to api
veelion Aug 5, 2022
44501ae
add batch_model pointer to DecodeResource
veelion Aug 5, 2022
7aabd7b
add batch processing source to decoder_srcs
veelion Aug 5, 2022
7a28701
jit export forward_encoder_batch()
veelion Aug 5, 2022
19114ee
add batch processing to Python binding
veelion Aug 5, 2022
7fe373a
Merge branch 'main' of https://github.com/wenet-e2e/wenet into batch
veelion Aug 10, 2022
0fd6c89
before change attention-scoring
veelion Aug 17, 2022
44392e6
add multi-threads for computing fbank, ctc searching
veelion Aug 18, 2022
e1e597e
to call jit script which supports batch_forward_attention_decoder()
veelion Aug 18, 2022
c25afd8
add run_batch flag to support BatchTorchAsrModel
veelion Aug 18, 2022
63c42f4
replace UpdateResult with decoder's get_batch_result()
veelion Aug 18, 2022
238fc8e
add FLAGS_enable_timestamp
veelion Aug 18, 2022
3576cb4
add FLAGS_run_batch for running batch decoding
veelion Aug 18, 2022
c3a17b1
fix: https://github.com/nbsdx/SimpleJSON/issues/4
veelion Aug 18, 2022
7adcd74
support run_batch
veelion Aug 18, 2022
a3045fc
add batch_connection_handler.h
veelion Aug 19, 2022
7bc634a
jit export batch_forward_attention_decoder()
veelion Aug 19, 2022
35e8d1a
add to decoder_srcs with batch_torch_asr_model.cc, batch_onnx_asr_mod…
veelion Aug 25, 2022
d172dd3
remove log msg
veelion Aug 25, 2022
d710a54
add is_fp16_
veelion Aug 25, 2022
1329090
add is_fp16 to Read()
veelion Aug 25, 2022
3aa76c7
add BatchOnnxAsrModel on GPU
veelion Aug 25, 2022
816a13a
add Yaml reader
veelion Aug 25, 2022
72d979b
Merge branch 'wenet-e2e:main' into main
veelion Aug 26, 2022
85e5ac5
export onnx gpu model for c++ runtime
veelion Aug 29, 2022
e9ffa53
improve memory managing if is_fp16
veelion Aug 30, 2022
5cb3c07
replace Eigen::half with <immintrin.h>
veelion Aug 30, 2022
48602b7
let model calculate attention score
veelion Aug 31, 2022
a71815d
let decoder return score
veelion Aug 31, 2022
2bfc65c
let attention decoder return score
veelion Aug 31, 2022
0544b95
remove lock if only one wav
veelion Sep 1, 2022
ed673b0
add gpu_id flag for BatchOnnxAsrModel
veelion Sep 1, 2022
d9d676c
add gpu_id flag for BatchOnnxAsrModel
veelion Sep 1, 2022
6afce87
fix memory issue of CreateTensor()
veelion Sep 1, 2022
8ea5a3e
add decoder_main_batch
veelion Sep 1, 2022
b3d82aa
re-link
veelion Sep 1, 2022
b6d5cdf
config cudnn_conv1d_pad_to_nc1d
veelion Sep 1, 2022
38e983d
use encoder's out : beam_log_probs/index (topk) to ctc_search, which …
veelion Sep 1, 2022
fcacb20
make ForwardEncoder() output topk
veelion Sep 2, 2022
5b28502
make batch_forward_encoder() return topk ctc_log_probs
veelion Sep 2, 2022
2ba32c3
only emptyCache() if USE_GPU
veelion Oct 21, 2022
e2259db
support GPU
veelion Oct 25, 2022
457d9b0
add more pytorch version
veelion Oct 25, 2022
14ffae6
save eos, sos to onnx_config for onnxruntime of C++
veelion Oct 25, 2022
1e9faaf
transformer decoder has no 'reverse_weight' in config
veelion Oct 25, 2022
a73b792
fix rescore_inputs
veelion Oct 25, 2022
65eb608
release GPU memory
veelion Oct 26, 2022
7d1700a
add onnx_version 1.13.1
veelion Nov 3, 2022
14d6cd5
replace GetInputName() with GetInputNameAllocated(), because GetInputN…
veelion Nov 3, 2022
65a88f9
Merge branch 'main' of https://github.com/wenet-e2e/wenet
veelion Nov 3, 2022
cf50ad0
Merge branch 'wenet-e2e:main' into main
veelion Nov 3, 2022
c64906a
merge
veelion Nov 3, 2022
4ae1c65
add description of 'run_batch' mode
veelion Nov 3, 2022
a0fb171
Merge run_batch mode to main branch
veelion Nov 3, 2022
068e4a7
fix batch_size
veelion Nov 3, 2022
aa1ac47
notes for a little bigger CER
veelion Nov 3, 2022
ba93bd9
remove trailing whitespace
veelion Nov 3, 2022
bbc7c15
fix flake8 error
veelion Nov 3, 2022
fb9e436
fix cpplint error
veelion Nov 4, 2022
cd85c84
fix flake8 error
veelion Nov 4, 2022
e8a0a24
pytorch version back to 1.10.0
veelion Nov 4, 2022
1e2af87
change reference to pointer of non-const object
veelion Nov 4, 2022
acea6da
fix github action build error
veelion Nov 4, 2022
8a7ac0a
Merge branch 'main' into main
veelion Nov 4, 2022
b61fae1
Merge branch 'main' of https://github.com/wenet-e2e/wenet
veelion Nov 7, 2022
8fda52a
supported GPU-compute feature(fbank) by kaldifeat
veelion Nov 30, 2022
e0b4e42
add fbank_cuda.h
veelion Nov 30, 2022
f3e2aee
Merge branch 'main' into vee-main
veelion Nov 30, 2022
11 changes: 11 additions & 0 deletions runtime/binding/python/cpp/binding.cc
@@ -1,4 +1,5 @@
// Copyright (c) 2022 Binbin Zhang([email protected])
// 2022 SoundDataConverge Co.LTD (Weiliang Chong)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -13,8 +14,10 @@
// limitations under the License.

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "api/wenet_api.h"
#include "api/batch_recognizer.h"

namespace py = pybind11;

@@ -37,4 +40,12 @@ PYBIND11_MODULE(_wenet, m) {
m.def("wenet_set_language", &wenet_set_language, "set language");
m.def("wenet_set_continuous_decoding", &wenet_set_continuous_decoding,
"enable continuous decoding or not");
py::class_<BatchRecognizer>(m, "BatchRecognizer")
.def(py::init<const char*>())
.def("set_enable_timestamp", &BatchRecognizer::set_enable_timestamp)
.def("AddContext", &BatchRecognizer::AddContext)
.def("set_context_score", &BatchRecognizer::set_context_score)
.def("set_language", &BatchRecognizer::set_language)
.def("DecodeData", &BatchRecognizer::DecodeData)
.def("Decode", &BatchRecognizer::Decode);
}
1 change: 1 addition & 0 deletions runtime/binding/python/py/__init__.py
@@ -1,2 +1,3 @@
from .decoder import Decoder # noqa
from .batch_decoder import BatchDecoder # noqa
from _wenet import wenet_set_log_level as set_log_level # noqa
79 changes: 79 additions & 0 deletions runtime/binding/python/py/batch_decoder.py
@@ -0,0 +1,79 @@
# Copyright (c) 2022 Binbin Zhang([email protected])
# 2022 SoundDataConverge Co.LTD (Weiliang Chong)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import _wenet

from .hub import Hub


class BatchDecoder:

def __init__(self,
model_dir: Optional[str] = None,
lang: str = 'chs',
nbest: int = 1,
enable_timestamp: bool = False,
context: Optional[List[str]] = None,
context_score: float = 3.0):
""" Init WeNet decoder
Args:
lang: language type of the model
nbest: nbest number for the final result
enable_timestamp: whether to enable word level timestamp
for the final result
context: context words
context_score: bonus score when the context is matched
"""
if model_dir is None:
model_dir = Hub.get_model_by_lang(lang)

self.d = _wenet.BatchRecognizer(model_dir)

self.set_language(lang)
self.enable_timestamp(enable_timestamp)
if context is not None:
self.add_context(context)
self.set_context_score(context_score)

def __del__(self):
del self.d

def enable_timestamp(self, flag: bool):
tag = 1 if flag else 0
self.d.set_enable_timestamp(tag)

def add_context(self, contexts: List[str]):
for c in contexts:
assert isinstance(c, str)
self.d.AddContext(c)

def set_context_score(self, score: float):
self.d.set_context_score(score)

def set_language(self, lang: str):
assert lang in ['chs', 'en']
self.d.set_language(lang)

def decode(self, pcms: List[bytes]) -> str:
""" Decode the input data

Args:
pcms: a list of wav pcm
"""
assert isinstance(pcms[0], bytes)
result = self.d.Decode(pcms)
return result
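
For illustration, a minimal usage sketch of the new BatchDecoder wrapper. The package import name, model directory, and wav file names below are assumptions, not part of this diff; decode() expects one 16 kHz, 16-bit mono PCM byte string per utterance.

# Hypothetical usage of BatchDecoder; import name and paths are assumptions.
import wave

import wenetruntime as wenet  # assumed package name for the runtime binding

def read_pcm(path: str) -> bytes:
    # Return raw 16 kHz / 16-bit mono PCM frames (no WAV header).
    with wave.open(path, 'rb') as w:
        assert w.getframerate() == 16000 and w.getsampwidth() == 2
        return w.readframes(w.getnframes())

decoder = wenet.BatchDecoder(model_dir='/path/to/model_dir',
                             enable_timestamp=True)
pcms = [read_pcm('utt1.wav'), read_pcm('utt2.wav')]
print(decoder.decode(pcms))  # result string for the whole batch
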
148 changes: 148 additions & 0 deletions runtime/core/api/batch_recognizer.h
@@ -0,0 +1,148 @@
// Copyright (c) 2022 Binbin Zhang ([email protected])
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef API_BATCH_RECOGNIZER_H_
#define API_BATCH_RECOGNIZER_H_

#include <memory>
#include <string>
#include <vector>
#include <utility>

#include "decoder/asr_decoder.h"
#include "decoder/batch_asr_decoder.h"
#include "decoder/batch_torch_asr_model.h"
#include "post_processor/post_processor.h"
#include "utils/file.h"
#include "utils/json.h"
#include "utils/string.h"

class BatchRecognizer {
public:
explicit BatchRecognizer(const std::string& model_dir, int num_threads = 1) {
// FeaturePipeline init
feature_config_ = std::make_shared<wenet::FeaturePipelineConfig>(80, 16000);
// Resource init
resource_ = std::make_shared<wenet::DecodeResource>();
wenet::BatchTorchAsrModel::InitEngineThreads(num_threads);
std::string model_path = wenet::JoinPath(model_dir, "final.zip");
CHECK(wenet::FileExists(model_path));

auto model = std::make_shared<wenet::BatchTorchAsrModel>();
model->Read(model_path);
resource_->batch_model = model;

// units.txt: E2E model unit
std::string unit_path = wenet::JoinPath(model_dir, "units.txt");
CHECK(wenet::FileExists(unit_path));
resource_->unit_table = std::shared_ptr<fst::SymbolTable>(
fst::SymbolTable::ReadText(unit_path));

std::string fst_path = wenet::JoinPath(model_dir, "TLG.fst");
if (wenet::FileExists(fst_path)) { // With LM
resource_->fst = std::shared_ptr<fst::Fst<fst::StdArc>>(
fst::Fst<fst::StdArc>::Read(fst_path));

std::string symbol_path = wenet::JoinPath(model_dir, "words.txt");
CHECK(wenet::FileExists(symbol_path));
resource_->symbol_table = std::shared_ptr<fst::SymbolTable>(
fst::SymbolTable::ReadText(symbol_path));
} else { // Without LM, symbol_table is the same as unit_table
resource_->symbol_table = resource_->unit_table;
}

// Context config init
context_config_ = std::make_shared<wenet::ContextConfig>();
decode_options_ = std::make_shared<wenet::DecodeOptions>();
post_process_opts_ = std::make_shared<wenet::PostProcessOptions>();
}

void InitDecoder() {
CHECK(decoder_ == nullptr);
// Optional init context graph
if (context_.size() > 0) {
context_config_->context_score = context_score_;
auto context_graph =
std::make_shared<wenet::ContextGraph>(*context_config_);
context_graph->BuildContextGraph(context_, resource_->symbol_table);
resource_->context_graph = context_graph;
}
// PostProcessor
if (language_ == "chs") { // TODO(Binbin Zhang): CJK(chs, jp, kr)
post_process_opts_->language_type = wenet::kMandarinEnglish;
} else {
post_process_opts_->language_type = wenet::kIndoEuropean;
}
resource_->post_processor =
std::make_shared<wenet::PostProcessor>(*post_process_opts_);
// Init decoder
decoder_ = std::make_shared<wenet::BatchAsrDecoder>(
feature_config_, resource_,
*decode_options_);
}

std::string Decode(const std::vector<std::string>& wavs) {
// Init decoder when it is called first time
if (decoder_ == nullptr) {
InitDecoder();
}
std::vector<std::vector<float>> wavs_float;
for (auto& wav : wavs) {
const int16_t* pcm = reinterpret_cast<const int16_t*>(wav.data());
int pcm_len = wav.size() / sizeof(int16_t);
std::vector<float> wav_float(pcm_len);
for (int i = 0; i < pcm_len; i++) {
wav_float[i] = static_cast<float>(*(pcm + i));
}
wavs_float.push_back(std::move(wav_float));
}
decoder_->Reset();
decoder_->Decode(wavs_float);
return decoder_->get_batch_result(nbest_, enable_timestamp_);
}

std::string DecodeData(const std::vector<std::vector<float>>& wavs) {
// Init decoder when it is called first time
if (decoder_ == nullptr) {
InitDecoder();
}
decoder_->Reset();
decoder_->Decode(wavs);
return decoder_->get_batch_result(nbest_, enable_timestamp_);
}

void set_nbest(int n) { nbest_ = n; }
void set_enable_timestamp(bool flag) { enable_timestamp_ = flag; }
void AddContext(const char* word) { context_.emplace_back(word); }
void set_context_score(float score) { context_score_ = score; }
void set_language(const char* lang) { language_ = lang; }

private:
std::shared_ptr<wenet::FeaturePipelineConfig> feature_config_ = nullptr;
std::shared_ptr<wenet::DecodeResource> resource_ = nullptr;
std::shared_ptr<wenet::DecodeOptions> decode_options_ = nullptr;
std::shared_ptr<wenet::BatchAsrDecoder> decoder_ = nullptr;
std::shared_ptr<wenet::ContextConfig> context_config_ = nullptr;
std::shared_ptr<wenet::PostProcessOptions> post_process_opts_ = nullptr;

int nbest_ = 1;
bool enable_timestamp_ = false;
std::vector<std::string> context_;
float context_score_ = 3.0;
std::string language_ = "chs";
};

#endif // API_BATCH_RECOGNIZER_H_
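
The header offers two entry points: Decode() takes raw 16-bit PCM byte strings and casts each sample to float internally (no 1/32768 scaling), while DecodeData() accepts float sample vectors directly. A hedged sketch of the float path through the Python binding; the model path is a placeholder and the samples are silence.

# Hypothetical use of DecodeData() via the pybind11 binding above.
import _wenet

r = _wenet.BatchRecognizer('/path/to/model_dir')  # placeholder path
r.set_language('chs')
# Samples are plain int16 values cast to float, matching Decode() above.
utt = [0.0] * 16000  # one second of 16 kHz silence
print(r.DecodeData([utt, utt]))  # batch of two utterances
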
3 changes: 3 additions & 0 deletions runtime/core/bin/CMakeLists.txt
@@ -1,6 +1,9 @@
add_executable(decoder_main decoder_main.cc)
target_link_libraries(decoder_main PUBLIC decoder)

add_executable(decoder_main_batch decoder_main_batch.cc)
target_link_libraries(decoder_main_batch PUBLIC decoder)

add_executable(label_checker_main label_checker_main.cc)
target_link_libraries(label_checker_main PUBLIC decoder)

51 changes: 51 additions & 0 deletions runtime/core/bin/api_batch_main.cc
@@ -0,0 +1,51 @@
// Copyright (c) 2022 Binbin Zhang ([email protected])
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "api/batch_recognizer.h"
#include "api/wenet_api.h"
#include "frontend/wav.h"
#include "utils/flags.h"
#include "utils/timer.h"

DEFINE_string(model_dir, "", "model dir path");
DEFINE_string(wav_path, "", "single wave path");
DEFINE_int32(batch_size, 1, "batch size of input");
DEFINE_int32(num_threads, 1, "number threads of intraop");
DEFINE_bool(enable_timestamp, false, "enable timestamps");

int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);

wenet_set_log_level(2);

BatchRecognizer br(FLAGS_model_dir, FLAGS_num_threads);
if (FLAGS_enable_timestamp) br.set_enable_timestamp(true);
wenet::WavReader wav_reader(FLAGS_wav_path);
std::vector<float> data;
data.insert(
data.end(), wav_reader.data(),
wav_reader.data() + wav_reader.num_samples());
std::vector<std::vector<float>> wavs;
for (int i = 0; i < FLAGS_batch_size - 1; i++) {
wavs.push_back(data);
}
wavs.push_back(std::move(data));
wenet::Timer timer;
std::string result = br.DecodeData(wavs);
int forward_time = timer.Elapsed();
VLOG(1) << "Decode() takes " << forward_time << " ms";
LOG(INFO) << result;
return 0;
}
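
A hedged invocation sketch for this tool; the binary location and input paths are assumptions, while the flag names match the DEFINE_* declarations above.

# Hypothetical driver script; binary and file locations are assumptions.
import subprocess

subprocess.run([
    './build/bin/api_batch_main',      # assumed build output path
    '--model_dir=/path/to/model_dir',
    '--wav_path=/path/to/test.wav',
    '--batch_size=8',                  # the wav is duplicated to fill the batch
    '--num_threads=1',
    '--enable_timestamp=true',
], check=True)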