[lint] use pre-commit to auto check and lint #183

Merged · 4 commits · Feb 7, 2025
7 changes: 0 additions & 7 deletions .github/workflows/doc.yml
@@ -8,13 +8,6 @@ jobs:
steps:
- uses: actions/checkout@v1

# install sphinx related package and build sphinx files
- uses: ammaraskar/sphinx-action@master
with:
docs-folder: "docs/"
pre-build-command: "pip install sphinx-markdown-tables nbsphinx jinja2 recommonmark sphinx_rtd_theme"


# add .nojekyll to notice Pages use the _* dirs
- name: copy the generated site
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
4 changes: 2 additions & 2 deletions .github/workflows/lint.yml
@@ -35,7 +35,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v1
with:
python-version: 3.x
python-version: 3.9
architecture: x64
- name: Fetch Wenet
uses: actions/checkout@v1
@@ -80,7 +80,7 @@ jobs:
- name: Run cpplint
run: |
set -eux
pip install cpplint
pip install cpplint==1.6.1
cpplint --version
cpplint --recursive .
if [ $? != 0 ]; then exit 1; fi
23 changes: 23 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,23 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: trailing-whitespace
- repo: https://github.com/pre-commit/mirrors-yapf
rev: 'v0.32.0'
hooks:
- id: yapf
- repo: https://github.com/pycqa/flake8
rev: '3.8.2'
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: 'v17.0.6'
hooks:
- id: clang-format
args: ['--style=file']
exclude: 'runtime/android/app/src/.*\.(json|java|js|m|mm|proto)'
- repo: https://github.com/cpplint/cpplint
rev: '1.6.1'
hooks:
- id: cpplint
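
A minimal sketch of how this config is exercised locally, assuming a working Python environment: the hook ids come from the file above, while the commands themselves are standard pre-commit CLI usage rather than anything added in this diff.

```sh
# One-time setup: install pre-commit and register the git hook in this clone.
pip install pre-commit
pre-commit install

# Run every configured hook (trailing-whitespace, yapf, flake8, clang-format, cpplint)
# against the whole repository instead of only the staged files.
pre-commit run --all-files

# Run a single hook while iterating, e.g. only clang-format on the C++ sources.
pre-commit run clang-format --all-files
```

Pinning the hook revisions (yapf v0.32.0, flake8 3.8.2, clang-format v17.0.6, cpplint 1.6.1) keeps local runs and the lint workflow in CI producing the same results.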
1 change: 1 addition & 0 deletions README.md
@@ -45,6 +45,7 @@ pip install torch torchaudio

``` sh
pip install -r requirements.txt
pre-commit install # for clean and tidy code
```

## Dataset
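As a hedged usage note on the `pre-commit install` step added to the README above (standard git and pre-commit behavior, not specific to this repository): once the hook is registered, every commit is checked automatically.

```sh
# With the hook installed, a normal commit runs the configured linters/formatters
# on the staged files and aborts if any hook fails or rewrites a file.
git commit -m "refactor: tidy feature pipeline"

# Skip the hooks for a one-off commit; CI still runs the same checks.
git commit -m "wip" --no-verify
```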
4 changes: 0 additions & 4 deletions docs/conf.py
@@ -14,14 +14,12 @@
# import sys
# sys.path.insert(0, os.path.abspath('.'))


# -- Project information -----------------------------------------------------

project = 'Wenet'
copyright = '2020, wenet-team'
author = 'wenet-team'


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
@@ -43,7 +41,6 @@
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']


# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
source_suffix = {
@@ -57,7 +54,6 @@
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
1 change: 1 addition & 0 deletions requirements.txt
@@ -15,3 +15,4 @@ pyflakes==2.2.0
lmdb
scipy
tqdm
pre-commit==3.5.0
22 changes: 11 additions & 11 deletions runtime/android/app/src/main/cpp/wekws.cc
@@ -82,18 +82,18 @@ void set_input_finished() {
// }

void start_spot() {
std::vector<std::vector<float>> feats;
feature_pipeline->Read(80, &feats);
std::vector<std::vector<float>> prob;
spotter->Forward(feats, &prob);
float max_prob = 0.0;
for (int t = 0; t < prob.size(); t++) {
for (int j = 0; j < prob[t].size(); j++) {
max_prob = std::max(prob[t][j], max_prob);
}
std::vector<std::vector<float>> feats;
feature_pipeline->Read(80, &feats);
std::vector<std::vector<float>> prob;
spotter->Forward(feats, &prob);
float max_prob = 0.0;
for (int t = 0; t < prob.size(); t++) {
for (int j = 0; j < prob[t].size(); j++) {
max_prob = std::max(prob[t][j], max_prob);
}
result = std::to_string(offset) + " prob: " + std::to_string(max_prob);
offset += prob.size();
}
result = std::to_string(offset) + " prob: " + std::to_string(max_prob);
offset += prob.size();
}

jstring get_result(JNIEnv* env, jobject) {
1 change: 0 additions & 1 deletion runtime/core/bin/kws_main.cc
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.


#include <iostream>
#include <string>

2 changes: 1 addition & 1 deletion runtime/core/frontend/feature_pipeline.h
@@ -21,8 +21,8 @@
#include <vector>

#include "frontend/fbank.h"
#include "utils/log.h"
#include "utils/blocking_queue.h"
#include "utils/log.h"

namespace wenet {

34 changes: 15 additions & 19 deletions runtime/core/kws/keyword_spotting.cc
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.


#include "kws/keyword_spotting.h"

#include <iostream>
@@ -35,30 +34,27 @@ KeywordSpotting::KeywordSpotting(const std::string& model_path) {
out_names_ = {"output", "r_cache"};
auto metadata = session_->GetModelMetadata();
Ort::AllocatorWithDefaultOptions allocator;
cache_dim_ = std::stoi(metadata.LookupCustomMetadataMap("cache_dim",
allocator));
cache_len_ = std::stoi(metadata.LookupCustomMetadataMap("cache_len",
allocator));
cache_dim_ =
std::stoi(metadata.LookupCustomMetadataMap("cache_dim", allocator));
cache_len_ =
std::stoi(metadata.LookupCustomMetadataMap("cache_len", allocator));
std::cout << "Kws Model Info:" << std::endl
<< "\tcache_dim: " << cache_dim_ << std::endl
<< "\tcache_len: " << cache_len_ << std::endl;
Reset();
}


void KeywordSpotting::Reset() {
Ort::MemoryInfo memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
cache_.resize(cache_dim_ * cache_len_, 0.0);
const int64_t cache_shape[] = {1, cache_dim_, cache_len_};
cache_ort_ = Ort::Value::CreateTensor<float>(
memory_info, cache_.data(), cache_.size(), cache_shape, 3);
cache_ort_ = Ort::Value::CreateTensor<float>(memory_info, cache_.data(),
cache_.size(), cache_shape, 3);
}


void KeywordSpotting::Forward(
const std::vector<std::vector<float>>& feats,
std::vector<std::vector<float>>* prob) {
void KeywordSpotting::Forward(const std::vector<std::vector<float>>& feats,
std::vector<std::vector<float>>* prob) {
prob->clear();
if (feats.size() == 0) return;
Ort::MemoryInfo memory_info =
@@ -78,9 +74,9 @@ void KeywordSpotting::Forward(
inputs.emplace_back(std::move(feats_ort));
inputs.emplace_back(std::move(cache_ort_));
// ort_outputs.size() == 2
std::vector<Ort::Value> ort_outputs = session_->Run(
Ort::RunOptions{nullptr}, in_names_.data(), inputs.data(),
inputs.size(), out_names_.data(), out_names_.size());
std::vector<Ort::Value> ort_outputs =
session_->Run(Ort::RunOptions{nullptr}, in_names_.data(), inputs.data(),
inputs.size(), out_names_.data(), out_names_.size());

// 3. Update cache
cache_ort_ = std::move(ort_outputs[1]);
@@ -92,9 +88,9 @@ void KeywordSpotting::Forward(
int output_dim = type_info.GetShape()[2];
prob->resize(num_outputs);
for (int i = 0; i < num_outputs; i++) {
(*prob)[i].resize(output_dim);
memcpy((*prob)[i].data(), data + i * output_dim,
sizeof(float) * output_dim);
(*prob)[i].resize(output_dim);
memcpy((*prob)[i].data(), data + i * output_dim,
sizeof(float) * output_dim);
}
}

2 changes: 0 additions & 2 deletions runtime/core/kws/keyword_spotting.h
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.


#ifndef KWS_KEYWORD_SPOTTING_H_
#define KWS_KEYWORD_SPOTTING_H_

@@ -55,7 +54,6 @@ class KeywordSpotting {
std::vector<float> cache_;
};


} // namespace wekws

#endif // KWS_KEYWORD_SPOTTING_H_
54 changes: 27 additions & 27 deletions runtime/core/utils/log.h
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.


#ifndef UTILS_LOG_H_
#define UTILS_LOG_H_

@@ -30,21 +29,21 @@ class Logger {
Logger(int severity, const char* func, const char* file, int line) {
severity_ = severity;
switch (severity) {
case INFO:
ss_ << "INFO (";
break;
case WARNING:
ss_ << "WARNING (";
break;
case ERROR:
ss_ << "ERROR (";
break;
case FATAL:
ss_ << "FATAL (";
break;
default:
severity_ = FATAL;
ss_ << "FATAL (";
case INFO:
ss_ << "INFO (";
break;
case WARNING:
ss_ << "WARNING (";
break;
case ERROR:
ss_ << "ERROR (";
break;
case FATAL:
ss_ << "FATAL (";
break;
default:
severity_ = FATAL;
ss_ << "FATAL (";
}
ss_ << func << "():" << file << ':' << line << ") ";
}
@@ -56,7 +55,8 @@ class Logger {
}
}

template <typename T> Logger& operator<<(const T &val) {
template <typename T>
Logger& operator<<(const T& val) {
ss_ << val;
return *this;
}
@@ -66,17 +66,17 @@
std::ostringstream ss_;
};

#define LOG(severity) ::wenet::Logger( \
::wenet::severity, __func__, __FILE__, __LINE__)
#define LOG(severity) \
::wenet::Logger(::wenet::severity, __func__, __FILE__, __LINE__)

#define CHECK(test) \
do { \
if (!(test)) { \
std::cerr << "CHECK (" << __func__ << "():" << __FILE__ << ":" \
<< __LINE__ << ") " << #test << std::endl; \
exit(-1); \
} \
} while (0)
#define CHECK(test) \
do { \
if (!(test)) { \
std::cerr << "CHECK (" << __func__ << "():" << __FILE__ << ":" \
<< __LINE__ << ") " << #test << std::endl; \
exit(-1); \
} \
} while (0)

} // namespace wenet

5 changes: 3 additions & 2 deletions tools/compute_cmvn_stats.py
@@ -16,6 +16,7 @@
class CollateFunc(object):
''' Collate function for AudioDataset
'''

def __init__(self, feat_dim, feat_type, resample_rate):
self.feat_dim = feat_dim
self.resample_rate = resample_rate
@@ -30,8 +31,7 @@ def __call__(self, batch):
value = item[1].strip().split(",")
assert len(value) == 3 or len(value) == 1
wav_path = value[0]
sample_rate = torchaudio.info(wav_path,
backend='sox').sample_rate
sample_rate = torchaudio.info(wav_path, backend='sox').sample_rate
resample_rate = sample_rate
# len(value) == 3 means segmented wav.scp,
# len(value) == 1 means original wav.scp
@@ -73,6 +73,7 @@ def __call__(self, batch):


class AudioDataset(Dataset):

def __init__(self, data_file):
self.items = []
with codecs.open(data_file, 'r', encoding='utf-8') as f: