From d5484963c7dcfc4b96a77cc2f634c01231c895f3 Mon Sep 17 00:00:00 2001 From: cdliang11 <1404056823@qq.com> Date: Thu, 6 Feb 2025 14:13:40 +0800 Subject: [PATCH 1/4] [lint] use pre-commit to auto check and lint --- .github/workflows/lint.yml | 2 +- .pre-commit-config.yaml | 21 +++++++++++++++++++++ README.md | 1 + requirements.txt | 1 + 4 files changed, 24 insertions(+), 1 deletion(-) create mode 100644 .pre-commit-config.yaml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index bf0cac9..1459fec 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -35,7 +35,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v1 with: - python-version: 3.x + python-version: 3.9 architecture: x64 - name: Fetch Wenet uses: actions/checkout@v1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..bee2fc6 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,21 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - repo: https://github.com/pre-commit/mirrors-yapf + rev: 'v0.32.0' + hooks: + - id: yapf + - repo: https://github.com/pycqa/flake8 + rev: '3.8.2' + hooks: + - id: flake8 + - repo: https://github.com/pre-commit/mirrors-clang-format + rev: 'v17.0.6' + hooks: + - id: clang-format + - repo: https://github.com/cpplint/cpplint + rev: '1.6.1' + hooks: + - id: cpplint diff --git a/README.md b/README.md index fa2fd30..61b20f1 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,7 @@ pip install torch torchaudio ``` sh pip install -r requirements.txt +pre-commit install # for clean and tidy code ``` ## Dataset diff --git a/requirements.txt b/requirements.txt index 1b72a96..496d789 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,3 +15,4 @@ pyflakes==2.2.0 lmdb scipy tqdm +pre-commit==3.5.0 From 363f68b7c629262ee5571a096b6a144403d36f12 Mon Sep 17 00:00:00 2001 From: cdliang11 <1404056823@qq.com> Date: Thu, 6 Feb 2025 14:33:10 +0800 Subject: [PATCH 2/4] [lint] auto format all by pre-commit, including c++, python --- .pre-commit-config.yaml | 2 + docs/conf.py | 4 - runtime/android/app/src/main/cpp/wekws.cc | 22 ++-- runtime/core/bin/kws_main.cc | 1 - runtime/core/frontend/feature_pipeline.h | 2 +- runtime/core/kws/keyword_spotting.cc | 34 +++-- runtime/core/kws/keyword_spotting.h | 2 - runtime/core/utils/log.h | 54 ++++---- tools/compute_cmvn_stats.py | 5 +- tools/make_list.py | 35 ++--- tools/wav2dur.py | 1 - wekws/bin/compute_det.py | 11 +- wekws/bin/compute_det_ctc.py | 65 +++++---- wekws/bin/export_onnx.py | 9 +- wekws/bin/plot_det_curve.py | 56 +++----- wekws/bin/score.py | 8 +- wekws/bin/score_ctc.py | 26 ++-- wekws/bin/stream_kws_ctc.py | 153 ++++++++++++---------- wekws/bin/stream_score_ctc.py | 55 ++++---- wekws/bin/train.py | 8 +- wekws/dataset/dataset.py | 14 +- wekws/dataset/processor.py | 24 ++-- wekws/model/classifier.py | 6 + wekws/model/cmvn.py | 1 + wekws/model/fsmn.py | 78 +++++------ wekws/model/kws_model.py | 11 +- wekws/model/loss.py | 47 ++++--- wekws/model/mdtc.py | 7 +- wekws/model/subsampling.py | 4 + wekws/model/tcn.py | 4 + wekws/utils/checkpoint.py | 1 - wekws/utils/cmvn.py | 1 + wekws/utils/executor.py | 10 +- 33 files changed, 412 insertions(+), 349 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bee2fc6..2e4cd40 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,6 +15,8 @@ repos: rev: 'v17.0.6' hooks: - id: clang-format + args: ['--style=file'] + exclude: 'runtime/android/app/src/.*\.(json|java|js|m|mm|proto)' - repo: https://github.com/cpplint/cpplint rev: '1.6.1' hooks: diff --git a/docs/conf.py b/docs/conf.py index 49abc10..3bc4256 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -14,14 +14,12 @@ # import sys # sys.path.insert(0, os.path.abspath('.')) - # -- Project information ----------------------------------------------------- project = 'Wenet' copyright = '2020, wenet-team' author = 'wenet-team' - # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be @@ -43,7 +41,6 @@ # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] - # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: source_suffix = { @@ -57,7 +54,6 @@ # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] - # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for diff --git a/runtime/android/app/src/main/cpp/wekws.cc b/runtime/android/app/src/main/cpp/wekws.cc index 18ea8e6..db6feca 100644 --- a/runtime/android/app/src/main/cpp/wekws.cc +++ b/runtime/android/app/src/main/cpp/wekws.cc @@ -82,18 +82,18 @@ void set_input_finished() { // } void start_spot() { - std::vector> feats; - feature_pipeline->Read(80, &feats); - std::vector> prob; - spotter->Forward(feats, &prob); - float max_prob = 0.0; - for (int t = 0; t < prob.size(); t++) { - for (int j = 0; j < prob[t].size(); j++) { - max_prob = std::max(prob[t][j], max_prob); - } + std::vector> feats; + feature_pipeline->Read(80, &feats); + std::vector> prob; + spotter->Forward(feats, &prob); + float max_prob = 0.0; + for (int t = 0; t < prob.size(); t++) { + for (int j = 0; j < prob[t].size(); j++) { + max_prob = std::max(prob[t][j], max_prob); } - result = std::to_string(offset) + " prob: " + std::to_string(max_prob); - offset += prob.size(); + } + result = std::to_string(offset) + " prob: " + std::to_string(max_prob); + offset += prob.size(); } jstring get_result(JNIEnv* env, jobject) { diff --git a/runtime/core/bin/kws_main.cc b/runtime/core/bin/kws_main.cc index 16e5bb5..f627959 100644 --- a/runtime/core/bin/kws_main.cc +++ b/runtime/core/bin/kws_main.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - #include #include diff --git a/runtime/core/frontend/feature_pipeline.h b/runtime/core/frontend/feature_pipeline.h index 3fdafa6..884b91f 100644 --- a/runtime/core/frontend/feature_pipeline.h +++ b/runtime/core/frontend/feature_pipeline.h @@ -21,8 +21,8 @@ #include #include "frontend/fbank.h" -#include "utils/log.h" #include "utils/blocking_queue.h" +#include "utils/log.h" namespace wenet { diff --git a/runtime/core/kws/keyword_spotting.cc b/runtime/core/kws/keyword_spotting.cc index cf1df8b..4fff6a1 100644 --- a/runtime/core/kws/keyword_spotting.cc +++ b/runtime/core/kws/keyword_spotting.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - #include "kws/keyword_spotting.h" #include @@ -35,30 +34,27 @@ KeywordSpotting::KeywordSpotting(const std::string& model_path) { out_names_ = {"output", "r_cache"}; auto metadata = session_->GetModelMetadata(); Ort::AllocatorWithDefaultOptions allocator; - cache_dim_ = std::stoi(metadata.LookupCustomMetadataMap("cache_dim", - allocator)); - cache_len_ = std::stoi(metadata.LookupCustomMetadataMap("cache_len", - allocator)); + cache_dim_ = + std::stoi(metadata.LookupCustomMetadataMap("cache_dim", allocator)); + cache_len_ = + std::stoi(metadata.LookupCustomMetadataMap("cache_len", allocator)); std::cout << "Kws Model Info:" << std::endl << "\tcache_dim: " << cache_dim_ << std::endl << "\tcache_len: " << cache_len_ << std::endl; Reset(); } - void KeywordSpotting::Reset() { Ort::MemoryInfo memory_info = - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); cache_.resize(cache_dim_ * cache_len_, 0.0); const int64_t cache_shape[] = {1, cache_dim_, cache_len_}; - cache_ort_ = Ort::Value::CreateTensor( - memory_info, cache_.data(), cache_.size(), cache_shape, 3); + cache_ort_ = Ort::Value::CreateTensor(memory_info, cache_.data(), + cache_.size(), cache_shape, 3); } - -void KeywordSpotting::Forward( - const std::vector>& feats, - std::vector>* prob) { +void KeywordSpotting::Forward(const std::vector>& feats, + std::vector>* prob) { prob->clear(); if (feats.size() == 0) return; Ort::MemoryInfo memory_info = @@ -78,9 +74,9 @@ void KeywordSpotting::Forward( inputs.emplace_back(std::move(feats_ort)); inputs.emplace_back(std::move(cache_ort_)); // ort_outputs.size() == 2 - std::vector ort_outputs = session_->Run( - Ort::RunOptions{nullptr}, in_names_.data(), inputs.data(), - inputs.size(), out_names_.data(), out_names_.size()); + std::vector ort_outputs = + session_->Run(Ort::RunOptions{nullptr}, in_names_.data(), inputs.data(), + inputs.size(), out_names_.data(), out_names_.size()); // 3. Update cache cache_ort_ = std::move(ort_outputs[1]); @@ -92,9 +88,9 @@ void KeywordSpotting::Forward( int output_dim = type_info.GetShape()[2]; prob->resize(num_outputs); for (int i = 0; i < num_outputs; i++) { - (*prob)[i].resize(output_dim); - memcpy((*prob)[i].data(), data + i * output_dim, - sizeof(float) * output_dim); + (*prob)[i].resize(output_dim); + memcpy((*prob)[i].data(), data + i * output_dim, + sizeof(float) * output_dim); } } diff --git a/runtime/core/kws/keyword_spotting.h b/runtime/core/kws/keyword_spotting.h index 14bf732..1b34a55 100644 --- a/runtime/core/kws/keyword_spotting.h +++ b/runtime/core/kws/keyword_spotting.h @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - #ifndef KWS_KEYWORD_SPOTTING_H_ #define KWS_KEYWORD_SPOTTING_H_ @@ -55,7 +54,6 @@ class KeywordSpotting { std::vector cache_; }; - } // namespace wekws #endif // KWS_KEYWORD_SPOTTING_H_ diff --git a/runtime/core/utils/log.h b/runtime/core/utils/log.h index 9d7601c..e893e3c 100644 --- a/runtime/core/utils/log.h +++ b/runtime/core/utils/log.h @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - #ifndef UTILS_LOG_H_ #define UTILS_LOG_H_ @@ -30,21 +29,21 @@ class Logger { Logger(int severity, const char* func, const char* file, int line) { severity_ = severity; switch (severity) { - case INFO: - ss_ << "INFO ("; - break; - case WARNING: - ss_ << "WARNING ("; - break; - case ERROR: - ss_ << "ERROR ("; - break; - case FATAL: - ss_ << "FATAL ("; - break; - default: - severity_ = FATAL; - ss_ << "FATAL ("; + case INFO: + ss_ << "INFO ("; + break; + case WARNING: + ss_ << "WARNING ("; + break; + case ERROR: + ss_ << "ERROR ("; + break; + case FATAL: + ss_ << "FATAL ("; + break; + default: + severity_ = FATAL; + ss_ << "FATAL ("; } ss_ << func << "():" << file << ':' << line << ") "; } @@ -56,7 +55,8 @@ class Logger { } } - template Logger& operator<<(const T &val) { + template + Logger& operator<<(const T& val) { ss_ << val; return *this; } @@ -66,17 +66,17 @@ class Logger { std::ostringstream ss_; }; -#define LOG(severity) ::wenet::Logger( \ - ::wenet::severity, __func__, __FILE__, __LINE__) +#define LOG(severity) \ + ::wenet::Logger(::wenet::severity, __func__, __FILE__, __LINE__) -#define CHECK(test) \ -do { \ - if (!(test)) { \ - std::cerr << "CHECK (" << __func__ << "():" << __FILE__ << ":" \ - << __LINE__ << ") " << #test << std::endl; \ - exit(-1); \ - } \ -} while (0) +#define CHECK(test) \ + do { \ + if (!(test)) { \ + std::cerr << "CHECK (" << __func__ << "():" << __FILE__ << ":" \ + << __LINE__ << ") " << #test << std::endl; \ + exit(-1); \ + } \ + } while (0) } // namespace wenet diff --git a/tools/compute_cmvn_stats.py b/tools/compute_cmvn_stats.py index 049a057..d8b716f 100755 --- a/tools/compute_cmvn_stats.py +++ b/tools/compute_cmvn_stats.py @@ -16,6 +16,7 @@ class CollateFunc(object): ''' Collate function for AudioDataset ''' + def __init__(self, feat_dim, feat_type, resample_rate): self.feat_dim = feat_dim self.resample_rate = resample_rate @@ -30,8 +31,7 @@ def __call__(self, batch): value = item[1].strip().split(",") assert len(value) == 3 or len(value) == 1 wav_path = value[0] - sample_rate = torchaudio.info(wav_path, - backend='sox').sample_rate + sample_rate = torchaudio.info(wav_path, backend='sox').sample_rate resample_rate = sample_rate # len(value) == 3 means segmented wav.scp, # len(value) == 1 means original wav.scp @@ -73,6 +73,7 @@ def __call__(self, batch): class AudioDataset(Dataset): + def __init__(self, data_file): self.items = [] with codecs.open(data_file, 'r', encoding='utf-8') as f: diff --git a/tools/make_list.py b/tools/make_list.py index 76b55cf..f8c436e 100755 --- a/tools/make_list.py +++ b/tools/make_list.py @@ -22,6 +22,7 @@ symbol_str = '[’!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+' + def split_mixed_label(input_str): tokens = [] s = input_str.lower() @@ -35,6 +36,7 @@ def split_mixed_label(input_str): s = s.replace(word, '', 1).strip(' ') return tokens + def query_token_set(txt, symbol_table, lexicon_table): tokens_str = tuple() tokens_idx = tuple() @@ -75,12 +77,10 @@ def query_token_set(txt, symbol_table, lexicon_table): else: if '' in symbol_table: tokens_idx = tokens_idx + (symbol_table[''], ) - logging.info( - f'{ch} is not in token set, replace with ') + logging.info(f'{ch} is not in token set, replace with ') else: tokens_idx = tokens_idx + (symbol_table[''], ) - logging.info( - f'{ch} is not in token set, replace with ') + logging.info(f'{ch} is not in token set, replace with ') return tokens_str, tokens_idx @@ -125,15 +125,14 @@ def query_token_list(txt, symbol_table, lexicon_table): else: if '' in symbol_table: tokens_idx.append(symbol_table['']) - logging.info( - f'{ch} is not in token set, replace with ') + logging.info(f'{ch} is not in token set, replace with ') else: tokens_idx.append(symbol_table['']) - logging.info( - f'{ch} is not in token set, replace with ') + logging.info(f'{ch} is not in token set, replace with ') return tokens_str, tokens_idx + def read_token(token_file): tokens_table = {} with open(token_file, 'r', encoding='utf8') as fin: @@ -162,9 +161,13 @@ def read_lexicon(lexicon_file): parser.add_argument('text_file', help='text file') parser.add_argument('duration_file', help='duration file') parser.add_argument('output_file', help='output list file') - parser.add_argument('--token_file', type=str, default=None, + parser.add_argument('--token_file', + type=str, + default=None, help='the path of tokens.txt') - parser.add_argument('--lexicon_file', type=str, default=None, + parser.add_argument('--lexicon_file', + type=str, + default=None, help='the path of lexicon.txt') args = parser.parse_args() @@ -195,13 +198,12 @@ def read_lexicon(lexicon_file): arr = line.strip().split(maxsplit=1) key = arr[0] tokens = None - if token_table is not None and lexicon_table is not None : + if token_table is not None and lexicon_table is not None: if len(arr) < 2: # for some utterence, no text txt = [1] # the /sil is indexed by 1 tokens = ["sil"] else: - tokens, txt = query_token_list(arr[1], - token_table, + tokens, txt = query_token_list(arr[1], token_table, lexicon_table) else: txt = int(arr[1]) @@ -212,8 +214,11 @@ def read_lexicon(lexicon_file): if tokens is None: line = dict(key=key, txt=txt, duration=duration, wav=wav) else: - line = dict(key=key, tok=tokens, txt=txt, - duration=duration, wav=wav) + line = dict(key=key, + tok=tokens, + txt=txt, + duration=duration, + wav=wav) json_line = json.dumps(line, ensure_ascii=False) fout.write(json_line + '\n') diff --git a/tools/wav2dur.py b/tools/wav2dur.py index d416b1a..2961149 100755 --- a/tools/wav2dur.py +++ b/tools/wav2dur.py @@ -5,7 +5,6 @@ import torchaudio - scp = sys.argv[1] dur_scp = sys.argv[2] diff --git a/wekws/bin/compute_det.py b/wekws/bin/compute_det.py index 32b0280..5dfd957 100644 --- a/wekws/bin/compute_det.py +++ b/wekws/bin/compute_det.py @@ -56,10 +56,15 @@ def load_label_and_score(keyword, label_file, score_file): parser.add_argument('--test_data', required=True, help='label file') parser.add_argument('--keyword', type=int, default=0, help='keyword label') parser.add_argument('--score_file', required=True, help='score file') - parser.add_argument('--step', type=float, default=0.01, + parser.add_argument('--step', + type=float, + default=0.01, help='threshold step') - parser.add_argument('--window_shift', type=int, default=50, - help='window_shift is used to skip the frames after triggered') + parser.add_argument( + '--window_shift', + type=int, + default=50, + help='window_shift is used to skip the frames after triggered') parser.add_argument('--stats_file', required=True, help='false reject/alarm stats file') diff --git a/wekws/bin/compute_det_ctc.py b/wekws/bin/compute_det_ctc.py index 9a8b5e7..c7b1a38 100644 --- a/wekws/bin/compute_det_ctc.py +++ b/wekws/bin/compute_det_ctc.py @@ -25,6 +25,7 @@ import pypinyin # for Chinese Character from tools.make_list import query_token_set, read_lexicon, read_token + def split_mixed_label(input_str): tokens = [] s = input_str.lower() @@ -44,6 +45,7 @@ def space_mixed_label(input_str): space_str = ''.join(f'{sub} ' for sub in splits) return space_str.strip() + def load_label_and_score(keywords_list, label_file, score_file, true_keywords): score_table = {} with open(score_file, 'r', encoding='utf8') as fin: @@ -85,7 +87,7 @@ def load_label_and_score(keywords_list, label_file, score_file, true_keywords): for obj in label_lists: assert 'key' in obj assert 'wav' in obj - assert 'tok' in obj # here we use the tokens + assert 'tok' in obj # here we use the tokens assert 'duration' in obj key = obj['key'] @@ -120,6 +122,7 @@ def load_label_and_score(keywords_list, label_file, score_file, true_keywords): return keyword_filler_table + def load_stats_file(stats_file): values = [] with open(stats_file, 'r', encoding='utf8') as fin: @@ -130,6 +133,7 @@ def load_stats_file(stats_file): values.reverse() return np.array(values) + def plot_det(dets_dir, figure_file, xlim=5, x_step=1, ylim=35, y_step=5): det_title = "DetCurve" plt.figure(dpi=200) @@ -154,41 +158,50 @@ def plot_det(dets_dir, figure_file, xlim=5, x_step=1, ylim=35, y_step=5): plt.legend(loc='best', fontsize=6) plt.savefig(figure_file) + if __name__ == '__main__': parser = argparse.ArgumentParser(description='compute det curve') parser.add_argument('--test_data', required=True, help='label file') - parser.add_argument('--keywords', type=str, default=None, + parser.add_argument('--keywords', + type=str, + default=None, help='keywords, split with comma(,)') - parser.add_argument('--token_file', type=str, default=None, + parser.add_argument('--token_file', + type=str, + default=None, help='the path of tokens.txt') - parser.add_argument('--lexicon_file', type=str, default=None, + parser.add_argument('--lexicon_file', + type=str, + default=None, help='the path of lexicon.txt') parser.add_argument('--score_file', required=True, help='score file') - parser.add_argument('--step', type=float, default=0.01, + parser.add_argument('--step', + type=float, + default=0.01, help='threshold step') - parser.add_argument('--window_shift', type=int, default=50, + parser.add_argument('--window_shift', + type=int, + default=50, help='window_shift is used to ' - 'skip the frames after triggered') + 'skip the frames after triggered') parser.add_argument('--stats_dir', required=False, default=None, help='false reject/alarm stats dir, ' - 'default in score_file') + 'default in score_file') parser.add_argument('--det_curve_path', required=False, default=None, help='det curve path, default is stats_dir/det.png') - parser.add_argument( - '--xlim', - type=int, - default=5, - help='xlim:range of x-axis, x is false alarm per hour') + parser.add_argument('--xlim', + type=int, + default=5, + help='xlim:range of x-axis, x is false alarm per hour') parser.add_argument('--x_step', type=int, default=1, help='step on x-axis') - parser.add_argument( - '--ylim', - type=int, - default=35, - help='ylim:range of y-axis, y is false rejection rate') + parser.add_argument('--ylim', + type=int, + default=35, + help='ylim:range of y-axis, y is false rejection rate') parser.add_argument('--y_step', type=int, default=5, help='step on y-axis') args = parser.parse_args() @@ -206,8 +219,8 @@ def plot_det(dets_dir, figure_file, xlim=5, x_step=1, ylim=35, y_step=5): strs, indexes = query_token_set(keyword, token_table, lexicon_table) true_keywords[keyword] = ''.join(strs) - keyword_filler_table = load_label_and_score( - keywords_list, args.test_data, args.score_file, true_keywords) + keyword_filler_table = load_label_and_score(keywords_list, args.test_data, + args.score_file, true_keywords) for keyword in keywords_list: keyword = true_keywords[keyword] @@ -226,7 +239,7 @@ def plot_det(dets_dir, figure_file, xlim=5, x_step=1, ylim=35, y_step=5): keyword_dur / 3600.0, keyword_num)) logging.info(' Filler duration: {} Hours'.format(filler_dur / 3600.0)) - if args.stats_dir : + if args.stats_dir: stats_dir = args.stats_dir else: stats_dir = os.path.dirname(args.score_file) @@ -247,8 +260,8 @@ def plot_det(dets_dir, figure_file, xlim=5, x_step=1, ylim=35, y_step=5): num_false_alarm = 0 # transverse the all filler_table - for key, confi in keyword_filler_table[ - keyword]['filler_table'].items(): + for key, confi in keyword_filler_table[keyword][ + 'filler_table'].items(): if confi >= threshold: num_false_alarm += 1 # print(f'false alarm: {keyword}, {key}, {confi}') @@ -263,9 +276,9 @@ def plot_det(dets_dir, figure_file, xlim=5, x_step=1, ylim=35, y_step=5): fout.write('{:.3f} {:.6f} {:.6f}\n'.format( threshold, false_alarm_per_hour, false_reject_rate)) threshold += args.step - if args.det_curve_path : + if args.det_curve_path: det_curve_path = args.det_curve_path else: det_curve_path = os.path.join(stats_dir, 'det.png') - plot_det(stats_dir, det_curve_path, - args.xlim, args.x_step, args.ylim, args.y_step) + plot_det(stats_dir, det_curve_path, args.xlim, args.x_step, args.ylim, + args.y_step) diff --git a/wekws/bin/export_onnx.py b/wekws/bin/export_onnx.py index f998774..dd9b458 100644 --- a/wekws/bin/export_onnx.py +++ b/wekws/bin/export_onnx.py @@ -58,17 +58,12 @@ def main(): dtype=torch.float) if is_fsmn: cache = cache.unsqueeze(-1).expand(-1, -1, -1, num_layers) + dynamic_axes = {'input': {1: 'T'}, 'output': {1: 'T'}} torch.onnx.export(model, (dummy_input, cache), args.onnx_model, input_names=['input', 'cache'], output_names=['output', 'r_cache'], - dynamic_axes={ - 'input': { - 1: 'T' - }, - 'output': { - 1: 'T' - }}, + dynamic_axes=dynamic_axes, opset_version=13, verbose=False, do_constant_folding=True) diff --git a/wekws/bin/plot_det_curve.py b/wekws/bin/plot_det_curve.py index 62b77f7..1a03650 100644 --- a/wekws/bin/plot_det_curve.py +++ b/wekws/bin/plot_det_curve.py @@ -30,14 +30,8 @@ def load_stats_file(stats_file): return np.array(values) -def plot_det_curve( - keywords, - stats_dir, - figure_file, - xlim, - x_step, - ylim, - y_step): +def plot_det_curve(keywords, stats_dir, figure_file, xlim, x_step, ylim, + y_step): plt.figure(dpi=200) plt.rcParams['xtick.direction'] = 'in' plt.rcParams['ytick.direction'] = 'in' @@ -61,26 +55,24 @@ def plot_det_curve( if __name__ == '__main__': parser = argparse.ArgumentParser(description='plot det curve') - parser.add_argument( - '--keywords_dict', - required=True, - help='path to the dictionary of keywords') - parser.add_argument('--stats_dir', required=True, help='dir of stats files') - parser.add_argument( - '--figure_file', - required=True, - help='path to save det curve') - parser.add_argument( - '--xlim', - type=int, - default=5, - help='xlim:range of x-axis, x is false alarm per hour') + parser.add_argument('--keywords_dict', + required=True, + help='path to the dictionary of keywords') + parser.add_argument('--stats_dir', + required=True, + help='dir of stats files') + parser.add_argument('--figure_file', + required=True, + help='path to save det curve') + parser.add_argument('--xlim', + type=int, + default=5, + help='xlim:range of x-axis, x is false alarm per hour') parser.add_argument('--x_step', type=int, default=1, help='step on x-axis') - parser.add_argument( - '--ylim', - type=int, - default=35, - help='ylim:range of y-axis, y is false rejection rate') + parser.add_argument('--ylim', + type=int, + default=35, + help='ylim:range of y-axis, y is false rejection rate') parser.add_argument('--y_step', type=int, default=5, help='step on y-axis') args = parser.parse_args() @@ -92,11 +84,5 @@ def plot_det_curve( if int(index) > -1: keywords.append(keyword) - plot_det_curve( - keywords, - args.stats_dir, - args.figure_file, - args.xlim, - args.x_step, - args.ylim, - args.y_step) + plot_det_curve(keywords, args.stats_dir, args.figure_file, args.xlim, + args.x_step, args.ylim, args.y_step) diff --git a/wekws/bin/score.py b/wekws/bin/score.py index 97e91e2..ba4c056 100644 --- a/wekws/bin/score.py +++ b/wekws/bin/score.py @@ -117,10 +117,10 @@ def main(): score = logits[i][:lengths[i]] for keyword_i in range(num_keywords): keyword_scores = score[:, keyword_i] - score_frames = ' '.join(['{:.6f}'.format(x) - for x in keyword_scores.tolist()]) - fout.write('{} {} {}\n'.format( - key, keyword_i, score_frames)) + score_frames = ' '.join( + ['{:.6f}'.format(x) for x in keyword_scores.tolist()]) + fout.write('{} {} {}\n'.format(key, keyword_i, + score_frames)) if batch_idx % 10 == 0: print('Progress batch {}'.format(batch_idx)) sys.stdout.flush() diff --git a/wekws/bin/score_ctc.py b/wekws/bin/score_ctc.py index 9be9da1..0f9b1d8 100644 --- a/wekws/bin/score_ctc.py +++ b/wekws/bin/score_ctc.py @@ -33,6 +33,7 @@ from wekws.model.loss import ctc_prefix_beam_search from tools.make_list import query_token_set, read_lexicon, read_token + def get_args(): parser = argparse.ArgumentParser(description='recognize with your model') parser.add_argument('--config', required=True, help='config file') @@ -65,16 +66,23 @@ def get_args(): action='store_true', default=False, help='Use pinned memory buffers used for reading') - parser.add_argument('--keywords', type=str, default=None, + parser.add_argument('--keywords', + type=str, + default=None, help='the keywords, split with comma(,)') - parser.add_argument('--token_file', type=str, default=None, + parser.add_argument('--token_file', + type=str, + default=None, help='the path of tokens.txt') - parser.add_argument('--lexicon_file', type=str, default=None, + parser.add_argument('--lexicon_file', + type=str, + default=None, help='the path of lexicon.txt') args = parser.parse_args() return args + def is_sublist(main_list, check_list): if len(main_list) < len(check_list): return -1 @@ -172,8 +180,7 @@ def main(): for i in range(len(keys)): key = keys[i] score = logits[i][:lengths[i]] - hyps = ctc_prefix_beam_search(score, - lengths[i], + hyps = ctc_prefix_beam_search(score, lengths[i], keywords_idxset) hit_keyword = None hit_score = 1.0 @@ -201,11 +208,10 @@ def main(): if hit_keyword is not None: fout.write('{} detected {} {:.3f}\n'.format( key, hit_keyword, hit_score)) - logging.info( - f"batch:{batch_idx}_{i} detect {hit_keyword} " - f"in {key} from {start} to {end} frame. " - f"duration {end - start}, " - f"score {hit_score}, Activated.") + logging.info(f"batch:{batch_idx}_{i} detect {hit_keyword} " + f"in {key} from {start} to {end} frame. " + f"duration {end - start}, " + f"score {hit_score}, Activated.") else: fout.write('{} rejected\n'.format(key)) logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.") diff --git a/wekws/bin/stream_kws_ctc.py b/wekws/bin/stream_kws_ctc.py index 66251fd..82d1f88 100644 --- a/wekws/bin/stream_kws_ctc.py +++ b/wekws/bin/stream_kws_ctc.py @@ -36,12 +36,18 @@ def get_args(): parser = argparse.ArgumentParser(description='detect keywords online.') parser.add_argument('--config', required=True, help='config file') - parser.add_argument('--wav_path', required=False, - default=None, help='test wave path.') - parser.add_argument('--wav_scp', required=False, - default=None, help='test wave scp.') - parser.add_argument('--result_file', required=False, - default=None, help='test result.') + parser.add_argument('--wav_path', + required=False, + default=None, + help='test wave path.') + parser.add_argument('--wav_scp', + required=False, + default=None, + help='test wave scp.') + parser.add_argument('--result_file', + required=False, + default=None, + help='test result.') parser.add_argument('--gpu', type=int, @@ -52,28 +58,34 @@ def get_args(): action='store_true', default=False, help='Use pinned memory buffers used for reading') - parser.add_argument('--keywords', type=str, default=None, + parser.add_argument('--keywords', + type=str, + default=None, help='the keywords, split with comma(,)') - parser.add_argument('--token_file', type=str, default=None, + parser.add_argument('--token_file', + type=str, + default=None, help='the path of tokens.txt') - parser.add_argument('--lexicon_file', type=str, default=None, + parser.add_argument('--lexicon_file', + type=str, + default=None, help='the path of lexicon.txt') parser.add_argument('--score_beam_size', default=3, type=int, help='The first prune beam, ' - 'filter out those frames with low scores.') + 'filter out those frames with low scores.') parser.add_argument('--path_beam_size', default=20, type=int, help='The second prune beam, ' - 'keep only path_beam_size candidates.') + 'keep only path_beam_size candidates.') parser.add_argument('--threshold', type=float, default=0.0, help='The threshold of kws. ' - 'If ctc_search probs exceed this value,' - 'the keyword will be activated.') + 'If ctc_search probs exceed this value,' + 'the keyword will be activated.') parser.add_argument('--min_frames', default=5, type=int, @@ -108,9 +120,8 @@ def is_sublist(main_list, check_list): else: return -1 -def ctc_prefix_beam_search(t, probs, - cur_hyps, - keywords_idxset, + +def ctc_prefix_beam_search(t, probs, cur_hyps, keywords_idxset, score_beam_size): ''' @@ -172,7 +183,7 @@ def ctc_prefix_beam_search(t, probs, if not math.isclose(pb, 0.0, abs_tol=0.000001): # Update *s-s -> *ss, - is for blank - n_prefix = prefix + (s,) + n_prefix = prefix + (s, ) n_pb, n_pnb, nodes = next_hyps[n_prefix] n_pnb = n_pnb + pb * ps nodes = cur_nodes.copy() @@ -180,7 +191,7 @@ def ctc_prefix_beam_search(t, probs, prob=ps)) # to record token prob next_hyps[n_prefix] = (n_pb, n_pnb, nodes) else: - n_prefix = prefix + (s,) + n_prefix = prefix + (s, ) n_pb, n_pnb, nodes = next_hyps[n_prefix] if nodes: if ps > nodes[-1]['prob']: # update frame and prob @@ -197,16 +208,30 @@ def ctc_prefix_beam_search(t, probs, next_hyps[n_prefix] = (n_pb, n_pnb, nodes) # 2.2 Second beam prune - next_hyps = sorted( - next_hyps.items(), key=lambda x: (x[1][0] + x[1][1]), reverse=True) + next_hyps = sorted(next_hyps.items(), + key=lambda x: (x[1][0] + x[1][1]), + reverse=True) return next_hyps + class KeyWordSpotter(torch.nn.Module): - def __init__(self, ckpt_path, config_path, token_path, lexicon_path, - threshold, min_frames=5, max_frames=250, interval_frames=50, - score_beam=3, path_beam=20, - gpu=-1, is_jit_model=False,): + + def __init__( + self, + ckpt_path, + config_path, + token_path, + lexicon_path, + threshold, + min_frames=5, + max_frames=250, + interval_frames=50, + score_beam=3, + path_beam=20, + gpu=-1, + is_jit_model=False, + ): super().__init__() os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) with open(config_path, 'r') as fin: @@ -216,25 +241,25 @@ def __init__(self, ckpt_path, config_path, token_path, lexicon_path, # feature related self.sample_rate = 16000 self.wave_remained = np.array([]) - self.num_mel_bins = dataset_conf[ - 'feature_extraction_conf']['num_mel_bins'] - self.frame_length = dataset_conf[ - 'feature_extraction_conf']['frame_length'] # in ms - self.frame_shift = dataset_conf[ - 'feature_extraction_conf']['frame_shift'] # in ms + self.num_mel_bins = dataset_conf['feature_extraction_conf'][ + 'num_mel_bins'] + self.frame_length = dataset_conf['feature_extraction_conf'][ + 'frame_length'] # in ms + self.frame_shift = dataset_conf['feature_extraction_conf'][ + 'frame_shift'] # in ms self.downsampling = dataset_conf.get('frame_skip', 1) - self.resolution = self.frame_shift / 1000 # in second + self.resolution = self.frame_shift / 1000 # in second # fsmn splice operation self.context_expansion = dataset_conf.get('context_expansion', False) self.left_context = 0 self.right_context = 0 if self.context_expansion: self.left_context = dataset_conf['context_expansion_conf']['left'] - self.right_context = dataset_conf['context_expansion_conf']['right'] + self.right_context = dataset_conf['context_expansion_conf'][ + 'right'] self.feature_remained = None self.feats_ctx_offset = 0 # after downsample, offset exist. - # model related if is_jit_model: model = torch.jit.load(ckpt_path) @@ -258,7 +283,6 @@ def __init__(self, ckpt_path, config_path, token_path, lexicon_path, f'{len(self.lexicon_table)} units loaded.') self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float) - # decoding and detection related self.score_beam = score_beam self.path_beam = path_beam @@ -273,7 +297,7 @@ def __init__(self, ckpt_path, config_path, token_path, lexicon_path, self.hit_keyword = None self.activated = False - self.total_frames = 0 # frame offset, for absolute time + self.total_frames = 0 # frame offset, for absolute time self.last_active_pos = -1 # the last frame of being activated self.result = {} @@ -289,8 +313,8 @@ def set_keywords(self, keywords): keywords_strset = {''} keywords_tokenmap = {'': 0} for keyword in keywords_list: - strs, indexes = query_token_set( - keyword, self.token_table, self.lexicon_table) + strs, indexes = query_token_set(keyword, self.token_table, + self.lexicon_table) keywords_token[keyword] = {} keywords_token[keyword]['token_id'] = indexes keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i) @@ -326,7 +350,7 @@ def accept_wave(self, wave): self.wave_remained = wave return None wave_tensor = torch.from_numpy(wave).float().to(self.device) - wave_tensor = wave_tensor.unsqueeze(0) # add a channel dimension + wave_tensor = wave_tensor.unsqueeze(0) # add a channel dimension feats = kaldi.fbank(wave_tensor, num_mel_bins=self.num_mel_bins, frame_length=self.frame_length, @@ -346,19 +370,19 @@ def accept_wave(self, wave): if self.feature_remained is None: # first chunk # pad first frame at the beginning, # replicate just support last dimension, so we do transpose. - feats_pad = F.pad( - feats.T, (self.left_context, 0), mode='replicate').T + feats_pad = F.pad(feats.T, (self.left_context, 0), + mode='replicate').T else: feats_pad = torch.cat((self.feature_remained, feats)) - ctx_frm = feats_pad.shape[0] - ( - self.right_context + self.right_context) + ctx_frm = feats_pad.shape[0] - (self.right_context + + self.right_context) ctx_win = (self.left_context + self.right_context + 1) ctx_dim = feats.shape[1] * ctx_win feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32) for i in range(ctx_frm): - feats_ctx[i] = torch.cat( - tuple(feats_pad[i: i + ctx_win])).unsqueeze(0) + feats_ctx[i] = torch.cat(tuple( + feats_pad[i:i + ctx_win])).unsqueeze(0) # update feature remained, and feats self.feature_remained = \ @@ -376,9 +400,7 @@ def accept_wave(self, wave): def decode_keywords(self, t, probs): absolute_time = t + self.total_frames # search next_hyps depend on current probs and hyps. - next_hyps = ctc_prefix_beam_search(absolute_time, - probs, - self.cur_hyps, + next_hyps = ctc_prefix_beam_search(absolute_time, probs, self.cur_hyps, self.keywords_idxset, self.score_beam) # update cur_hyps. note: the hyps is sort by path score(pnb+pb), @@ -437,11 +459,10 @@ def execute_detection(self, t): f"is lower than {self.interval_frames}, Deactivated. ") elif self.hit_score < self.threshold: - logging.info( - f"Frame {absolute_time} detect {hit_keyword} " - f"from {start} to {end} frame. " - f"but {self.hit_score} " - f"is lower than {self.threshold}, Deactivated. ") + logging.info(f"Frame {absolute_time} detect {hit_keyword} " + f"from {start} to {end} frame. " + f"but {self.hit_score} " + f"is lower than {self.threshold}, Deactivated. ") elif self.min_frames > duration or duration > self.max_frames: logging.info( @@ -462,10 +483,10 @@ def forward(self, wave_chunk): feature = self.accept_wave(wave_chunk) if feature is None or feature.size(0) < 1: return {} # # the feature is not enough to get result. - feature = feature.unsqueeze(0) # add a batch dimension + feature = feature.unsqueeze(0) # add a batch dimension logits, self.in_cache = self.model(feature, self.in_cache) probs = logits.softmax(2) # (batch_size, maxlen, vocab_size) - probs = probs[0].cpu() # remove batch dimension + probs = probs[0].cpu() # remove batch dimension for (t, prob) in enumerate(probs): t *= self.downsampling self.decode_keywords(t, prob) @@ -503,26 +524,20 @@ def reset_all(self): self.feature_remained = None self.feats_ctx_offset = 0 # after downsample, offset exist. self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float) - self.total_frames = 0 # frame offset, for absolute time + self.total_frames = 0 # frame offset, for absolute time self.last_active_pos = -1 # the last frame of being activated self.result = {} + def demo(): args = get_args() logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s') - kws = KeyWordSpotter(args.checkpoint, - args.config, - args.token_file, - args.lexicon_file, - args.threshold, - args.min_frames, - args.max_frames, - args.interval_frames, - args.score_beam_size, - args.path_beam_size, - args.gpu, + kws = KeyWordSpotter(args.checkpoint, args.config, args.token_file, + args.lexicon_file, args.threshold, args.min_frames, + args.max_frames, args.interval_frames, + args.score_beam_size, args.path_beam_size, args.gpu, args.jit_model) # actually this could be done in __init__ method, @@ -543,7 +558,7 @@ def demo(): # We inference every 0.3 seconds, in streaming fashion. interval = int(0.3 * 16000) * 2 for i in range(0, len(wav), interval): - chunk_wav = wav[i: min(i + interval, len(wav))] + chunk_wav = wav[i:min(i + interval, len(wav))] result = kws.forward(chunk_wav) print(result) @@ -574,7 +589,7 @@ def demo(): # We inference every 0.3 seconds, in streaming fashion. interval = int(0.3 * 16000) * 2 for i in range(0, len(wav), interval): - chunk_wav = wav[i: min(i + interval, len(wav))] + chunk_wav = wav[i:min(i + interval, len(wav))] result = kws.forward(chunk_wav) if 'state' in result and result['state'] == 1: activated = True @@ -588,9 +603,9 @@ def demo(): if fout: fout.write('{} rejected\n'.format(utt_name)) - if fout: fout.close() + if __name__ == '__main__': demo() diff --git a/wekws/bin/stream_score_ctc.py b/wekws/bin/stream_score_ctc.py index c03e66b..ef3dd29 100644 --- a/wekws/bin/stream_score_ctc.py +++ b/wekws/bin/stream_score_ctc.py @@ -66,28 +66,34 @@ def get_args(): action='store_true', default=False, help='Use pinned memory buffers used for reading') - parser.add_argument('--keywords', type=str, default=None, + parser.add_argument('--keywords', + type=str, + default=None, help='the keywords, split with comma(,)') - parser.add_argument('--token_file', type=str, default=None, + parser.add_argument('--token_file', + type=str, + default=None, help='the path of tokens.txt') - parser.add_argument('--lexicon_file', type=str, default=None, + parser.add_argument('--lexicon_file', + type=str, + default=None, help='the path of lexicon.txt') parser.add_argument('--score_beam_size', default=3, type=int, help='The first prune beam, f' - 'ilter out those frames with low scores.') + 'ilter out those frames with low scores.') parser.add_argument('--path_beam_size', default=20, type=int, help='The second prune beam, ' - 'keep only path_beam_size candidates.') + 'keep only path_beam_size candidates.') parser.add_argument('--threshold', type=float, default=0.0, help='The threshold of kws. ' - 'If ctc_search probs exceed this value,' - 'the keyword will be activated.') + 'If ctc_search probs exceed this value,' + 'the keyword will be activated.') parser.add_argument('--min_frames', default=5, type=int, @@ -215,7 +221,7 @@ def main(): # 2. CTC beam search step by step for t in range(0, maxlen): probs = ctc_probs[t] # (vocab_size,) - t *= downsampling_factor # the real time + t *= downsampling_factor # the real time # key: prefix, value (pb, pnb), default value(-inf, -inf) next_hyps = defaultdict(lambda: (0.0, 0.0, [])) @@ -225,8 +231,8 @@ def main(): # filter prob score that is too small filter_probs = [] filter_index = [] - for prob, idx in zip( - top_k_probs.tolist(), top_k_index.tolist()): + for prob, idx in zip(top_k_probs.tolist(), + top_k_index.tolist()): if keywords_idxset is not None: if prob > 0.05 and idx in keywords_idxset: filter_probs.append(prob) @@ -250,7 +256,8 @@ def main(): nodes = cur_nodes.copy() next_hyps[prefix] = (n_pb, n_pnb, nodes) elif s == last: - if not math.isclose(pnb, 0.0, abs_tol=0.000001): + if not math.isclose(pnb, 0.0, + abs_tol=0.000001): # Update *ss -> *s; n_pb, n_pnb, nodes = next_hyps[prefix] n_pnb = n_pnb + pnb * ps @@ -263,15 +270,15 @@ def main(): if not math.isclose(pb, 0.0, abs_tol=0.000001): # Update *s-s -> *ss, - is for blank - n_prefix = prefix + (s,) + n_prefix = prefix + (s, ) n_pb, n_pnb, nodes = next_hyps[n_prefix] n_pnb = n_pnb + pb * ps nodes = cur_nodes.copy() - nodes.append(dict( - token=s, frame=t, prob=ps)) + nodes.append( + dict(token=s, frame=t, prob=ps)) next_hyps[n_prefix] = (n_pb, n_pnb, nodes) else: - n_prefix = prefix + (s,) + n_prefix = prefix + (s, ) n_pb, n_pnb, nodes = next_hyps[n_prefix] if nodes: # update frame and prob @@ -280,19 +287,19 @@ def main(): # nodes[-1]['frame'] = t # avoid change other beam has this node. nodes.pop() - nodes.append(dict( - token=s, frame=t, prob=ps)) + nodes.append( + dict(token=s, frame=t, prob=ps)) else: nodes = cur_nodes.copy() - nodes.append(dict( - token=s, frame=t, prob=ps)) + nodes.append( + dict(token=s, frame=t, prob=ps)) n_pnb = n_pnb + pb * ps + pnb * ps next_hyps[n_prefix] = (n_pb, n_pnb, nodes) # 2.2 Second beam prune - next_hyps = sorted( - next_hyps.items(), - key=lambda x: (x[1][0] + x[1][1]), reverse=True) + next_hyps = sorted(next_hyps.items(), + key=lambda x: (x[1][0] + x[1][1]), + reverse=True) cur_hyps = next_hyps[:args.path_beam_size] @@ -310,8 +317,8 @@ def main(): if offset != -1: hit_keyword = word start = prefix_nodes[offset]['frame'] - end = prefix_nodes[ - offset + len(lab) - 1]['frame'] + end = prefix_nodes[offset + len(lab) - + 1]['frame'] for idx in range(offset, offset + len(lab)): hit_score *= prefix_nodes[idx]['prob'] break diff --git a/wekws/bin/train.py b/wekws/bin/train.py index 025ebf9..f226d1e 100644 --- a/wekws/bin/train.py +++ b/wekws/bin/train.py @@ -77,12 +77,8 @@ def get_args(): default=100, type=int, help='prefetch number') - parser.add_argument('--reverb_lmdb', - default=None, - help='reverb lmdb file') - parser.add_argument('--noise_lmdb', - default=None, - help='noise lmdb file') + parser.add_argument('--reverb_lmdb', default=None, help='reverb lmdb file') + parser.add_argument('--noise_lmdb', default=None, help='noise lmdb file') args = parser.parse_args() return args diff --git a/wekws/dataset/dataset.py b/wekws/dataset/dataset.py index 897b87c..bbcef2c 100644 --- a/wekws/dataset/dataset.py +++ b/wekws/dataset/dataset.py @@ -24,6 +24,7 @@ class Processor(IterableDataset): + def __init__(self, source, f, *args, **kw): assert callable(f) self.source = source @@ -48,6 +49,7 @@ def apply(self, f): class DistributedSampler: + def __init__(self, shuffle=True, partition=True): self.epoch = -1 self.update() @@ -96,6 +98,7 @@ def sample(self, data): class DataList(IterableDataset): + def __init__(self, lists, shuffle=True, partition=True): self.lists = lists self.sampler = DistributedSampler(shuffle, partition) @@ -113,7 +116,8 @@ def __iter__(self): yield data -def Dataset(data_list_file, conf, +def Dataset(data_list_file, + conf, partition=True, reverb_lmdb=None, noise_lmdb=None): @@ -144,12 +148,12 @@ def Dataset(data_list_file, conf, dataset = Processor(dataset, processor.speed_perturb) if reverb_lmdb and conf.get('reverb_prob', 0) > 0: reverb_data = LmdbData(reverb_lmdb) - dataset = Processor(dataset, processor.add_reverb, - reverb_data, conf['reverb_prob']) + dataset = Processor(dataset, processor.add_reverb, reverb_data, + conf['reverb_prob']) if noise_lmdb and conf.get('noise_prob', 0) > 0: noise_data = LmdbData(noise_lmdb) - dataset = Processor(dataset, processor.add_noise, - noise_data, conf['noise_prob']) + dataset = Processor(dataset, processor.add_noise, noise_data, + conf['noise_prob']) feature_extraction_conf = conf.get('feature_extraction_conf', {}) if feature_extraction_conf['feature_type'] == 'mfcc': dataset = Processor(dataset, processor.compute_mfcc, diff --git a/wekws/dataset/processor.py b/wekws/dataset/processor.py index bb1dde7..313ff96 100644 --- a/wekws/dataset/processor.py +++ b/wekws/dataset/processor.py @@ -263,6 +263,7 @@ def shuffle(data, shuffle_size=1000): for x in buf: yield x + def context_expansion(data, left=1, right=1): """ expand left and right frames Args: @@ -287,8 +288,8 @@ def context_expansion(data, left=1, right=1): # replication pad left margin for idx in range(left): for cpx in range(left - idx): - feats_ctx[idx, cpx * feats.shape[1]:(cpx + 1) - * feats.shape[1]] = feats_ctx[left, :feats.shape[1]] + feats_ctx[idx, cpx * feats.shape[1]:(cpx + 1) * + feats.shape[1]] = feats_ctx[left, :feats.shape[1]] feats_ctx = feats_ctx[:feats_ctx.shape[0] - right] sample['feat'] = feats_ctx @@ -309,6 +310,7 @@ def frame_skip(data, skip_rate=1): sample['feat'] = feats_skip yield sample + def batch(data, batch_size=16): """ Static batch the data by `batch_size` @@ -354,17 +356,19 @@ def padding(data): if isinstance(sample[0]['label'], int): padded_labels = torch.tensor([sample[i]['label'] for i in order], dtype=torch.int32) - label_lengths = torch.tensor([1 for i in order], - dtype=torch.int32) + label_lengths = torch.tensor([1 for i in order], dtype=torch.int32) else: sorted_labels = [ - torch.tensor(sample[i]['label'], dtype=torch.int32) for i in order + torch.tensor(sample[i]['label'], dtype=torch.int32) + for i in order ] - label_lengths = torch.tensor([len(sample[i]['label']) for i in order], - dtype=torch.int32) - padded_labels = pad_sequence( - sorted_labels, batch_first=True, padding_value=-1) - yield (sorted_keys, padded_feats, padded_labels, feats_lengths, label_lengths) + label_lengths = torch.tensor( + [len(sample[i]['label']) for i in order], dtype=torch.int32) + padded_labels = pad_sequence(sorted_labels, + batch_first=True, + padding_value=-1) + yield (sorted_keys, padded_feats, padded_labels, feats_lengths, + label_lengths) def add_reverb(data, reverb_source, aug_prob): diff --git a/wekws/model/classifier.py b/wekws/model/classifier.py index 190b30e..d76e2c9 100644 --- a/wekws/model/classifier.py +++ b/wekws/model/classifier.py @@ -18,6 +18,7 @@ class GlobalClassifier(nn.Module): """Add a global average pooling before the classifier""" + def __init__(self, classifier: nn.Module): super(GlobalClassifier, self).__init__() self.classifier = classifier @@ -29,6 +30,7 @@ def forward(self, x: torch.Tensor): class LastClassifier(nn.Module): """Select last frame to do the classification""" + def __init__(self, classifier: nn.Module): super(LastClassifier, self).__init__() self.classifier = classifier @@ -37,8 +39,10 @@ def forward(self, x: torch.Tensor): x = x[:, -1, :] return self.classifier(x) + class ElementClassifier(nn.Module): """Classify all the frames in an utterance""" + def __init__(self, classifier: nn.Module): super(ElementClassifier, self).__init__() self.classifier = classifier @@ -46,8 +50,10 @@ def __init__(self, classifier: nn.Module): def forward(self, x: torch.Tensor): return self.classifier(x) + class LinearClassifier(nn.Module): """ Wrapper of Linear """ + def __init__(self, input_dim, output_dim): super().__init__() self.linear = torch.nn.Linear(input_dim, output_dim) diff --git a/wekws/model/cmvn.py b/wekws/model/cmvn.py index 2e211aa..18496e1 100644 --- a/wekws/model/cmvn.py +++ b/wekws/model/cmvn.py @@ -17,6 +17,7 @@ class GlobalCMVN(torch.nn.Module): + def __init__(self, mean: torch.Tensor, istd: torch.Tensor, diff --git a/wekws/model/fsmn.py b/wekws/model/fsmn.py index 555c760..e62132d 100644 --- a/wekws/model/fsmn.py +++ b/wekws/model/fsmn.py @@ -39,8 +39,7 @@ def __init__(self, input_dim, output_dim): self.quant = torch.quantization.QuantStub() self.dequant = torch.quantization.DeQuantStub() - def forward(self, - input: Tuple[torch.Tensor, torch.Tensor]): + def forward(self, input: Tuple[torch.Tensor, torch.Tensor]): if isinstance(input, tuple): input, in_cache = input else: @@ -103,8 +102,7 @@ def __init__(self, input_dim, output_dim): self.quant = torch.quantization.QuantStub() self.dequant = torch.quantization.DeQuantStub() - def forward(self, - input: Tuple[torch.Tensor, torch.Tensor]): + def forward(self, input: Tuple[torch.Tensor, torch.Tensor]): if isinstance(input, tuple): input, in_cache = input else: @@ -195,54 +193,54 @@ def __init__( self.lstride = lstride self.rstride = rstride - self.conv_left = nn.Conv2d( - self.dim, - self.dim, [lorder, 1], - dilation=[lstride, 1], - groups=self.dim, - bias=False) + self.conv_left = nn.Conv2d(self.dim, + self.dim, [lorder, 1], + dilation=[lstride, 1], + groups=self.dim, + bias=False) if rorder > 0: - self.conv_right = nn.Conv2d( - self.dim, - self.dim, [rorder, 1], - dilation=[rstride, 1], - groups=self.dim, - bias=False) + self.conv_right = nn.Conv2d(self.dim, + self.dim, [rorder, 1], + dilation=[rstride, 1], + groups=self.dim, + bias=False) else: self.conv_right = None self.quant = torch.quantization.QuantStub() self.dequant = torch.quantization.DeQuantStub() - def forward(self, - input: Tuple[torch.Tensor, torch.Tensor]): + def forward(self, input: Tuple[torch.Tensor, torch.Tensor]): if isinstance(input, tuple): input, in_cache = input - else : + else: in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float) x = torch.unsqueeze(input, 1) x_per = x.permute(0, 3, 2, 1) - if in_cache is None or len(in_cache) == 0 : - x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride - + self.rorder * self.rstride, 0]) + if in_cache is None or len(in_cache) == 0: + x_pad = F.pad(x_per, [ + 0, 0, + (self.lorder - 1) * self.lstride + self.rorder * self.rstride, + 0 + ]) else: in_cache = in_cache.to(x_per.device) x_pad = torch.cat((in_cache, x_per), dim=2) - in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride - + self.rorder * self.rstride):, :] + in_cache = x_pad[:, :, -( + (self.lorder - 1) * self.lstride + self.rorder * self.rstride):, :] y_left = x_pad[:, :, :-self.rorder * self.rstride, :] y_left = self.quant(y_left) y_left = self.conv_left(y_left) y_left = self.dequant(y_left) - out = x_pad[:, :, (self.lorder - 1) * self.lstride: -self.rorder * + out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left if self.conv_right is not None: # y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) - y_right = x_pad[:, :, -( - x_per.size(2) + self.rorder * self.rstride):, :] + y_right = x_pad[:, :, + -(x_per.size(2) + self.rorder * self.rstride):, :] y_right = y_right[:, :, self.rstride:, :] y_right = self.quant(y_right) y_right = self.conv_right(y_right) @@ -347,11 +345,10 @@ def __init__(self, input_dim, output_dim): self.relu = nn.ReLU() self.dropout = nn.Dropout(0.1) - def forward(self, - input: Tuple[torch.Tensor, torch.Tensor]): + def forward(self, input: Tuple[torch.Tensor, torch.Tensor]): if isinstance(input, tuple): input, in_cache = input - else : + else: in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float) out = self.relu(input) # out = self.dropout(out) @@ -391,11 +388,10 @@ def _build_repeats( rstride=1, ): repeats = [ - nn.Sequential( - LinearTransform(linear_dim, proj_dim), - FSMNBlock(proj_dim, proj_dim, lorder, rorder, 1, 1), - AffineTransform(proj_dim, linear_dim), - RectifiedLinear(linear_dim, linear_dim)) + nn.Sequential(LinearTransform(linear_dim, proj_dim), + FSMNBlock(proj_dim, proj_dim, lorder, rorder, 1, 1), + AffineTransform(proj_dim, linear_dim), + RectifiedLinear(linear_dim, linear_dim)) for i in range(fsmn_layers) ] @@ -474,11 +470,15 @@ def forward( in_cache(torch.Tensor): (B, D, C), C is the accumulated cache size """ - if in_cache is None or len(in_cache) == 0 : - in_cache = [torch.zeros(0, 0, 0, 0, dtype=torch.float) - for _ in range(len(self.fsmn))] + if in_cache is None or len(in_cache) == 0: + in_cache = [ + torch.zeros(0, 0, 0, 0, dtype=torch.float) + for _ in range(len(self.fsmn)) + ] else: - in_cache = [in_cache[:, :, :, i: i + 1] for i in range(in_cache.size(-1))] + in_cache = [ + in_cache[:, :, :, i:i + 1] for i in range(in_cache.size(-1)) + ] input = (input, in_cache) x1 = self.in_linear1(input) x2 = self.in_linear2(x1) diff --git a/wekws/model/kws_model.py b/wekws/model/kws_model.py index 349690b..4918ea5 100644 --- a/wekws/model/kws_model.py +++ b/wekws/model/kws_model.py @@ -40,6 +40,7 @@ class KWSModel(nn.Module): nn.Sigmoid for wakeup word nn.Identity for speech command dataset """ + def __init__( self, idim: int, @@ -74,11 +75,11 @@ def forward( x = self.activation(x) return x, out_cache - def forward_softmax(self, - x: torch.Tensor, - in_cache: torch.Tensor = torch.zeros( - 0, 0, 0, dtype=torch.float) - ) -> Tuple[torch.Tensor, torch.Tensor]: + def forward_softmax( + self, + x: torch.Tensor, + in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) + ) -> Tuple[torch.Tensor, torch.Tensor]: if self.global_cmvn is not None: x = self.global_cmvn(x) x = self.preprocessing(x) diff --git a/wekws/model/loss.py b/wekws/model/loss.py index 42045a0..05d430b 100644 --- a/wekws/model/loss.py +++ b/wekws/model/loss.py @@ -98,6 +98,7 @@ def acc_frame( correct = pred.eq(target.long().view_as(pred)).sum().item() return correct * 100.0 / logits.size(0) + def acc_utterance(logits: torch.Tensor, target: torch.Tensor, logits_length: torch.Tensor, target_length: torch.Tensor): if logits is None: @@ -127,8 +128,9 @@ def acc_utterance(logits: torch.Tensor, target: torch.Tensor, total_sub += result['sub'] total_del += result['del'] - return float(total_word - total_ins - total_sub - - total_del) * 100.0 / total_word + return float(total_word - total_ins - total_sub - + total_del) * 100.0 / total_word + def ctc_loss(logits: torch.Tensor, target: torch.Tensor, @@ -152,12 +154,16 @@ def ctc_loss(logits: torch.Tensor, # logits: (B, L, D) -> (L, B, D) logits = logits.transpose(0, 1) logits = logits.log_softmax(2) - loss = F.ctc_loss( - logits, target, logits_lengths, target_lengths, reduction='sum') + loss = F.ctc_loss(logits, + target, + logits_lengths, + target_lengths, + reduction='sum') loss = loss / logits.size(1) # batch mean return loss, acc + def cross_entropy(logits: torch.Tensor, target: torch.Tensor): """ Cross Entropy Loss Attributes: @@ -174,13 +180,15 @@ def cross_entropy(logits: torch.Tensor, target: torch.Tensor): return loss, acc -def criterion(type: str, - logits: torch.Tensor, - target: torch.Tensor, - lengths: torch.Tensor, - target_lengths: torch.Tensor = None, - min_duration: int = 0, - validation: bool = False, ): +def criterion( + type: str, + logits: torch.Tensor, + target: torch.Tensor, + lengths: torch.Tensor, + target_lengths: torch.Tensor = None, + min_duration: int = 0, + validation: bool = False, +): if type == 'ce': loss, acc = cross_entropy(logits, target) return loss, acc @@ -188,12 +196,13 @@ def criterion(type: str, loss, acc = max_pooling_loss(logits, target, lengths, min_duration) return loss, acc elif type == 'ctc': - loss, acc = ctc_loss( - logits, target, lengths, target_lengths, validation) + loss, acc = ctc_loss(logits, target, lengths, target_lengths, + validation) return loss, acc else: exit(1) + def ctc_prefix_beam_search( logits: torch.Tensor, logits_lengths: torch.Tensor, @@ -293,8 +302,9 @@ def ctc_prefix_beam_search( next_hyps[n_prefix] = (n_pb, n_pnb, nodes) # 2.2 Second beam prune - next_hyps = sorted( - next_hyps.items(), key=lambda x: (x[1][0] + x[1][1]), reverse=True) + next_hyps = sorted(next_hyps.items(), + key=lambda x: (x[1][0] + x[1][1]), + reverse=True) cur_hyps = next_hyps[:path_beam_size] @@ -430,10 +440,9 @@ def calculate(self, lab, rec): elif self.space[i][j]['error'] == 'non': # starting point break else: # shouldn't reach here - print( - 'this should not happen, ' - 'i = {i} , j = {j} , error = {error}' - .format(i=i, j=j, error=self.space[i][j]['error'])) + print('this should not happen, ' + 'i = {i} , j = {j} , error = {error}'.format( + i=i, j=j, error=self.space[i][j]['error'])) return result def overall(self): diff --git a/wekws/model/mdtc.py b/wekws/model/mdtc.py index 090dde5..2891880 100644 --- a/wekws/model/mdtc.py +++ b/wekws/model/mdtc.py @@ -22,6 +22,7 @@ class DSDilatedConv1d(nn.Module): """Dilated Depthwise-Separable Convolution""" + def __init__( self, in_channels: int, @@ -59,6 +60,7 @@ def forward(self, inputs: torch.Tensor): class TCNBlock(nn.Module): + def __init__( self, in_channels: int, @@ -120,6 +122,7 @@ def forward( class TCNStack(nn.Module): + def __init__( self, in_channels: int, @@ -205,6 +208,7 @@ class MDTC(nn.Module): extracts multi-scale features from different hidden layers of MDTC with different receptive fields. """ + def __init__( self, stack_num: int, @@ -263,7 +267,8 @@ def forward( out_caches.append(c_out) offset += block.padding - outputs = torch.zeros_like(outputs_list[-1], dtype=outputs_list[-1].dtype) + outputs = torch.zeros_like(outputs_list[-1], + dtype=outputs_list[-1].dtype) for x in outputs_list: outputs += x outputs = outputs.transpose(1, 2) # (B, T, D) diff --git a/wekws/model/subsampling.py b/wekws/model/subsampling.py index 06d6311..cbe6ab7 100644 --- a/wekws/model/subsampling.py +++ b/wekws/model/subsampling.py @@ -19,6 +19,7 @@ class SubsamplingBase(torch.nn.Module): + def __init__(self): super().__init__() self.subsampling_rate = 1 @@ -27,6 +28,7 @@ def __init__(self): class NoSubsampling(SubsamplingBase): """No subsampling in accordance to the 'none' preprocessing """ + def __init__(self): super().__init__() @@ -37,6 +39,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class LinearSubsampling1(SubsamplingBase): """Linear transform the input without subsampling """ + def __init__(self, idim: int, odim: int): super().__init__() self.out = torch.nn.Sequential( @@ -61,6 +64,7 @@ def fuse_modules(self): class Conv1dSubsampling1(SubsamplingBase): """Conv1d transform without subsampling """ + def __init__(self, idim: int, odim: int): super().__init__() self.out = torch.nn.Sequential( diff --git a/wekws/model/tcn.py b/wekws/model/tcn.py index 36a6a29..85a6d14 100644 --- a/wekws/model/tcn.py +++ b/wekws/model/tcn.py @@ -21,6 +21,7 @@ class Block(nn.Module): + def __init__(self, channel: int, kernel_size: int, @@ -64,6 +65,7 @@ def fuse_modules(self): class CnnBlock(Block): + def __init__(self, channel: int, kernel_size: int, @@ -89,6 +91,7 @@ def fuse_modules(self): class DsCnnBlock(Block): """ Depthwise Separable Convolution """ + def __init__(self, channel: int, kernel_size: int, @@ -117,6 +120,7 @@ def fuse_modules(self): class TCN(nn.Module): + def __init__(self, num_layers: int, channel: int, diff --git a/wekws/utils/checkpoint.py b/wekws/utils/checkpoint.py index 4bbd3e2..db4e466 100644 --- a/wekws/utils/checkpoint.py +++ b/wekws/utils/checkpoint.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import logging import os import re diff --git a/wekws/utils/cmvn.py b/wekws/utils/cmvn.py index b28fad3..b19c2d7 100644 --- a/wekws/utils/cmvn.py +++ b/wekws/utils/cmvn.py @@ -44,6 +44,7 @@ def load_cmvn(json_cmvn_file): cmvn = np.array([means, variance]) return cmvn + def load_kaldi_cmvn(cmvn_file): """ Load the kaldi format cmvn stats file and no need to calculate diff --git a/wekws/utils/executor.py b/wekws/utils/executor.py index 2283e71..51d736f 100644 --- a/wekws/utils/executor.py +++ b/wekws/utils/executor.py @@ -21,6 +21,7 @@ class Executor: + def __init__(self): self.step = 0 @@ -44,7 +45,10 @@ def train(self, model, optimizer, data_loader, device, writer, args): continue logits, _ = model(feats) loss_type = args.get('criterion', 'max_pooling') - loss, acc = criterion(loss_type, logits, target, feats_lengths, + loss, acc = criterion(loss_type, + logits, + target, + feats_lengths, target_lengths=label_lengths, min_duration=min_duration, validation=False) @@ -80,7 +84,9 @@ def cv(self, model, data_loader, device, args): continue logits, _ = model(feats) loss, acc = criterion(args.get('criterion', 'max_pooling'), - logits, target, feats_lengths, + logits, + target, + feats_lengths, target_lengths=label_lengths, min_duration=0, validation=True) From c5efb63df75edd7aea77165081f406ba65415259 Mon Sep 17 00:00:00 2001 From: cdliang11 <1404056823@qq.com> Date: Thu, 6 Feb 2025 14:37:23 +0800 Subject: [PATCH 3/4] [lint] cpplint==1.6.1 --- .github/workflows/lint.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 1459fec..502d20d 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -80,7 +80,7 @@ jobs: - name: Run cpplint run: | set -eux - pip install cpplint + pip install cpplint==1.6.1 cpplint --version cpplint --recursive . if [ $? != 0 ]; then exit 1; fi From 5df4a5bcac1b8b1e3b3550794cad0bfa0b9a30b9 Mon Sep 17 00:00:00 2001 From: cdliang11 <1404056823@qq.com> Date: Thu, 6 Feb 2025 14:48:22 +0800 Subject: [PATCH 4/4] [lint] remove sphinx --- .github/workflows/doc.yml | 7 ------- 1 file changed, 7 deletions(-) diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index 68bce32..3937a57 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -8,13 +8,6 @@ jobs: steps: - uses: actions/checkout@v1 - # install sphinx related package and build sphinx files - - uses: ammaraskar/sphinx-action@master - with: - docs-folder: "docs/" - pre-build-command: "pip install sphinx-markdown-tables nbsphinx jinja2 recommonmark sphinx_rtd_theme" - - # add .nojekyll to notice Pages use the _* dirs - name: copy the generated site if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}