diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index a94a1ff673..30f6462167 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -5,6 +5,7 @@ pybind11_add_module(_sherpa_onnx display.cc endpoint.cc features.cc + keyword-spotter.cc offline-ctc-fst-decoder-config.cc offline-lm-config.cc offline-model-config.cc diff --git a/sherpa-onnx/python/csrc/keyword-spotter.cc b/sherpa-onnx/python/csrc/keyword-spotter.cc new file mode 100644 index 0000000000..144992605d --- /dev/null +++ b/sherpa-onnx/python/csrc/keyword-spotter.cc @@ -0,0 +1,82 @@ +// sherpa-onnx/python/csrc/keyword-spotter.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/keyword-spotter.h" + +#include +#include + +#include "sherpa-onnx/csrc/keyword-spotter.h" + +namespace sherpa_onnx { + +static void PybindKeywordResult(py::module *m) { + using PyClass = KeywordResult; + py::class_(*m, "KeywordResult") + .def_property_readonly( + "keyword", + [](PyClass &self) -> py::str { + return py::str(PyUnicode_DecodeUTF8(self.keyword.c_str(), + self.keyword.size(), "ignore")); + }) + .def_property_readonly( + "tokens", + [](PyClass &self) -> std::vector { return self.tokens; }) + .def_property_readonly( + "timestamps", + [](PyClass &self) -> std::vector { return self.timestamps; }); +} + +static void PybindKeywordSpotterConfig(py::module *m) { + using PyClass = KeywordSpotterConfig; + py::class_(*m, "KeywordSpotterConfig") + .def(py::init(), + py::arg("feat_config"), py::arg("model_config"), + py::arg("max_active_paths") = 4, py::arg("num_trailing_blanks") = 1, + py::arg("keywords_score") = 1.0, + py::arg("keywords_threshold") = 0.25, py::arg("keywords_file") = "") + .def_readwrite("feat_config", &PyClass::feat_config) + .def_readwrite("model_config", &PyClass::model_config) + .def_readwrite("max_active_paths", &PyClass::max_active_paths) + .def_readwrite("num_trailing_blanks", &PyClass::num_trailing_blanks) + .def_readwrite("keywords_score", &PyClass::keywords_score) + .def_readwrite("keywords_threshold", &PyClass::keywords_threshold) + .def_readwrite("keywords_file", &PyClass::keywords_file) + .def("__str__", &PyClass::ToString); +} + +void PybindKeywordSpotter(py::module *m) { + PybindKeywordResult(m); + PybindKeywordSpotterConfig(m); + + using PyClass = KeywordSpotter; + py::class_(*m, "KeywordSpotter") + .def(py::init(), py::arg("config"), + py::call_guard()) + .def( + "create_stream", + [](const PyClass &self) { return self.CreateStream(); }, + py::call_guard()) + .def( + "create_stream", + [](PyClass &self, const std::string &keywords) { + return self.CreateStream(keywords); + }, + py::arg("keywords"), py::call_guard()) + .def("is_ready", &PyClass::IsReady, + py::call_guard()) + .def("decode_stream", &PyClass::DecodeStream, + py::call_guard()) + .def( + "decode_streams", + [](PyClass &self, std::vector ss) { + self.DecodeStreams(ss.data(), ss.size()); + }, + py::call_guard()) + .def("get_result", &PyClass::GetResult, + py::call_guard()); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/keyword-spotter.h b/sherpa-onnx/python/csrc/keyword-spotter.h new file mode 100644 index 0000000000..dce0bae02a --- /dev/null +++ b/sherpa-onnx/python/csrc/keyword-spotter.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/keyword-spotter.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_KEYWORD_SPOTTER_H_ +#define SHERPA_ONNX_PYTHON_CSRC_KEYWORD_SPOTTER_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindKeywordSpotter(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_KEYWORD_SPOTTER_H_ diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index 37728426d0..bdc38bbe9c 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -8,6 +8,7 @@ #include "sherpa-onnx/python/csrc/display.h" #include "sherpa-onnx/python/csrc/endpoint.h" #include "sherpa-onnx/python/csrc/features.h" +#include "sherpa-onnx/python/csrc/keyword-spotter.h" #include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h" #include "sherpa-onnx/python/csrc/offline-lm-config.h" #include "sherpa-onnx/python/csrc/offline-model-config.h" @@ -35,6 +36,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { PybindOnlineStream(&m); PybindEndpoint(&m); PybindOnlineRecognizer(&m); + PybindKeywordSpotter(&m); PybindDisplay(&m);