diff --git a/.gitignore b/.gitignore index 47d10a97ac..06834c3c6c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ build *.zip *.tgz +*.sw? onnxruntime-* icefall-* run.sh diff --git a/sherpa-onnx/csrc/hypothesis.h b/sherpa-onnx/csrc/hypothesis.h index 4e534c3754..ada88c8cec 100644 --- a/sherpa-onnx/csrc/hypothesis.h +++ b/sherpa-onnx/csrc/hypothesis.h @@ -29,9 +29,21 @@ struct Hypothesis { std::vector<int32_t> timestamps; // The acoustic probability for each token in ys. - // Only used for keyword spotting task. + // Used for keyword spotting task. + // For transducer modified beam-search and greedy-search, + // this is filled with log_posterior scores. std::vector<float> ys_probs; + // lm_probs[i] contains the lm score for each token in ys. + // Used only in transducer modified beam-search. + // Elements filled only if LM is used. + std::vector<float> lm_probs; + + // context_scores[i] contains the context-graph score for each token in ys. + // Used only in transducer modified beam-search. + // Elements filled only if `ContextGraph` is used. + std::vector<float> context_scores; + // The total score of ys in log space. 
// It contains only acoustic scores double log_prob = 0; diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index b3f31cdd0e..f31a18d473 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -69,6 +69,10 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, r.timestamps.push_back(time); } + r.ys_probs = std::move(src.ys_probs); + r.lm_probs = std::move(src.lm_probs); + r.context_scores = std::move(src.context_scores); + r.segment = segment; r.start_time = frames_since_start * frame_shift_ms / 1000.; diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index 9cf7930627..ea7e9f9056 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -18,56 +18,50 @@ namespace sherpa_onnx { -std::string OnlineRecognizerResult::AsJsonString() const { - std::ostringstream os; - os << "{"; - os << "\"is_final\":" << (is_final ? 
"true" : "false") << ", "; - os << "\"segment\":" << segment << ", "; - os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time - << ", "; - - os << "\"text\"" - << ": "; - os << "\"" << text << "\"" - << ", "; - - os << "\"" - << "timestamps" - << "\"" - << ": "; - os << "["; - +/// Helper for `OnlineRecognizerResult::AsJsonString()` +template<typename T> +std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) { + std::ostringstream oss; + oss << std::fixed << std::setprecision(precision); + oss << "[ "; std::string sep = ""; - for (auto t : timestamps) { - os << sep << std::fixed << std::setprecision(2) << t; + for (const auto& item : vec) { + oss << sep << item; sep = ", "; } - os << "], "; - - os << "\"" - << "tokens" - << "\"" - << ":"; - os << "["; - - sep = ""; - auto oldFlags = os.flags(); - for (const auto &t : tokens) { - if (t.size() == 1 && static_cast<uint8_t>(t[0]) > 0x7f) { - const uint8_t *p = reinterpret_cast<const uint8_t *>(t.c_str()); - os << sep << "\"" - << "<0x" << std::hex << std::uppercase << static_cast<uint32_t>(p[0]) - << ">" - << "\""; - os.flags(oldFlags); - } else { - os << sep << "\"" << t << "\""; - } + oss << " ]"; + return oss.str(); +} + +/// Helper for `OnlineRecognizerResult::AsJsonString()` +template<> // explicit specialization for T = std::string +std::string VecToString<std::string>(const std::vector<std::string>& vec, + int32_t) { // ignore 2nd arg + std::ostringstream oss; + oss << "[ "; + std::string sep = ""; + for (const auto& item : vec) { + oss << sep << "\"" << item << "\""; sep = ", "; } - os << "]"; - os << "}"; + oss << " ]"; + return oss.str(); +} +std::string OnlineRecognizerResult::AsJsonString() const { + std::ostringstream os; + os << "{ "; + os << "\"text\": " << "\"" << text << "\"" << ", "; + os << "\"tokens\": " << VecToString(tokens) << ", "; + os << "\"timestamps\": " << VecToString(timestamps, 2) << ", "; + os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", "; + os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", "; + os << 
"\"context_scores\": " << VecToString(context_scores, 6) << ", "; + os << "\"segment\": " << segment << ", "; + os << "\"start_time\": " << std::fixed << std::setprecision(2) + << start_time << ", "; + os << "\"is_final\": " << (is_final ? "true" : "false"); + os << "}"; return os.str(); } diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index f0580d9cf6..ec8875e680 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -40,6 +40,12 @@ struct OnlineRecognizerResult { /// timestamps[i] records the time in seconds when tokens[i] is decoded. std::vector<float> timestamps; + std::vector<float> ys_probs; //< log-prob scores from ASR model + std::vector<float> lm_probs; //< log-prob scores from language model + // + /// log-domain scores from "hot-phrase" contextual boosting + std::vector<float> context_scores; + /// ID of this segment /// When an endpoint is detected, it is incremented int32_t segment = 0; @@ -58,6 +64,9 @@ struct OnlineRecognizerResult { * "text": "The recognition result", * "tokens": [x, x, x], * "timestamps": [x, x, x], + * "ys_probs": [x, x, x], + * "lm_probs": [x, x, x], + * "context_scores": [x, x, x], * "segment": x, * "start_time": x, * "is_final": true|false diff --git a/sherpa-onnx/csrc/online-transducer-decoder.cc b/sherpa-onnx/csrc/online-transducer-decoder.cc index 7a1e5a4332..0c51eabb53 100644 --- a/sherpa-onnx/csrc/online-transducer-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-decoder.cc @@ -37,6 +37,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( frame_offset = other.frame_offset; timestamps = other.timestamps; + ys_probs = other.ys_probs; + lm_probs = other.lm_probs; + context_scores = other.context_scores; + return *this; } @@ -60,6 +64,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( frame_offset = other.frame_offset; timestamps = std::move(other.timestamps); + ys_probs = std::move(other.ys_probs); + lm_probs = 
std::move(other.lm_probs); + context_scores = std::move(other.context_scores); + return *this; } diff --git a/sherpa-onnx/csrc/online-transducer-decoder.h b/sherpa-onnx/csrc/online-transducer-decoder.h index 68a8fae43b..6265366044 100644 --- a/sherpa-onnx/csrc/online-transducer-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-decoder.h @@ -26,6 +26,10 @@ struct OnlineTransducerDecoderResult { /// timestamps[i] contains the output frame index where tokens[i] is decoded. std::vector<int32_t> timestamps; + std::vector<float> ys_probs; + std::vector<float> lm_probs; + std::vector<float> context_scores; + // Cache decoder_out for endpointing Ort::Value decoder_out; diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc index e79a8e2f4f..c026e28a49 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc @@ -71,9 +71,11 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks( r->tokens = std::vector<int64_t>(start, end); } + void OnlineTransducerGreedySearchDecoder::Decode( Ort::Value encoder_out, std::vector<OnlineTransducerDecoderResult> *result) { + std::vector<int64_t> encoder_out_shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape(); @@ -97,6 +99,7 @@ void OnlineTransducerGreedySearchDecoder::Decode( break; } } + if (is_batch_decoder_out_cached) { auto &r = result->front(); std::vector<int64_t> decoder_out_shape = @@ -124,6 +127,7 @@ void OnlineTransducerGreedySearchDecoder::Decode( if (blank_penalty_ > 0.0) { p_logit[0] -= blank_penalty_; // assuming blank id is 0 } + auto y = static_cast<int32_t>(std::distance( static_cast<const float *>(p_logit), std::max_element(static_cast<const float *>(p_logit), @@ -138,6 +142,17 @@ void OnlineTransducerGreedySearchDecoder::Decode( } else { ++r.num_trailing_blanks; } + + // export the per-token log scores + if (y != 0 && y != unk_id_) { + LogSoftmax(p_logit, vocab_size); // renormalize probabilities, + // save time by doing it only for + // emitted symbols + const float 
*p_logprob = p_logit; // rename p_logit as p_logprob, + // now it contains normalized + // probability + r.ys_probs.push_back(p_logprob[y]); + } } if (emitted) { Ort::Value decoder_input = model_->BuildDecoderInput(*result); diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc index f676f45da9..e37ba63d48 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc @@ -59,6 +59,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks( std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end()); r->tokens = std::move(tokens); r->timestamps = std::move(hyp.timestamps); + + // export per-token scores + r->ys_probs = std::move(hyp.ys_probs); + r->lm_probs = std::move(hyp.lm_probs); + r->context_scores = std::move(hyp.context_scores); + r->num_trailing_blanks = hyp.num_trailing_blanks; } @@ -180,6 +186,28 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( new_hyp.log_prob = p_logprob[k] + context_score - prev_lm_log_prob; // log_prob only includes the // score of the transducer + // export the per-token log scores + if (new_token != 0 && new_token != unk_id_) { + const Hypothesis& prev_i = prev[hyp_index]; + // subtract 'prev[i]' path scores, which were added before + // getting topk tokens + float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob; + new_hyp.ys_probs.push_back(y_prob); + + if (lm_) { // export only when LM is used + float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob; + if (lm_scale_ != 0.0) { + lm_prob /= lm_scale_; // remove lm-scale + } + new_hyp.lm_probs.push_back(lm_prob); + } + + // export only when `ContextGraph` is used + if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) { + new_hyp.context_scores.push_back(context_score); + } + } + hyps.Add(std::move(new_hyp)); } // for (auto k : topk) 
cur.push_back(std::move(hyps)); diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc index c10cb3b4f0..0213bd7b27 100644 --- a/sherpa-onnx/python/csrc/online-recognizer.cc +++ b/sherpa-onnx/python/csrc/online-recognizer.cc @@ -28,7 +28,26 @@ static void PybindOnlineRecognizerResult(py::module *m) { [](PyClass &self) -> float { return self.start_time; }) .def_property_readonly( "timestamps", - [](PyClass &self) -> std::vector<float> { return self.timestamps; }); + [](PyClass &self) -> std::vector<float> { return self.timestamps; }) + .def_property_readonly( + "ys_probs", + [](PyClass &self) -> std::vector<float> { return self.ys_probs; }) + .def_property_readonly( + "lm_probs", + [](PyClass &self) -> std::vector<float> { return self.lm_probs; }) + .def_property_readonly( + "context_scores", + [](PyClass &self) -> std::vector<float> { + return self.context_scores; + }) + .def_property_readonly( + "segment", + [](PyClass &self) -> int32_t { return self.segment; }) + .def_property_readonly( + "is_final", + [](PyClass &self) -> bool { return self.is_final; }) + .def("as_json_string", &PyClass::AsJsonString, + py::call_guard<py::gil_scoped_release>()); } static void PybindOnlineRecognizerConfig(py::module *m) { diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index 70de7a2aae..6d95cd6dd1 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -503,6 +503,9 @@ def is_ready(self, s: OnlineStream) -> bool: def get_result(self, s: OnlineStream) -> str: return self.recognizer.get_result(s).text.strip() + def get_result_as_json_string(self, s: OnlineStream) -> str: + return self.recognizer.get_result(s).as_json_string() + def tokens(self, s: OnlineStream) -> List[str]: return self.recognizer.get_result(s).tokens @@ -512,6 +515,15 @@ def timestamps(self, s: OnlineStream) -> List[float]: def start_time(self, s: OnlineStream) -> float: return 
self.recognizer.get_result(s).start_time + def ys_probs(self, s: OnlineStream) -> List[float]: + return self.recognizer.get_result(s).ys_probs + + def lm_probs(self, s: OnlineStream) -> List[float]: + return self.recognizer.get_result(s).lm_probs + + def context_scores(self, s: OnlineStream) -> List[float]: + return self.recognizer.get_result(s).context_scores + def is_endpoint(self, s: OnlineStream) -> bool: return self.recognizer.is_endpoint(s)