Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Track token scores #571

Merged
merged 4 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
build
*.zip
*.tgz
*.sw?
onnxruntime-*
icefall-*
run.sh
Expand Down
14 changes: 13 additions & 1 deletion sherpa-onnx/csrc/hypothesis.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,21 @@ struct Hypothesis {
std::vector<int32_t> timestamps;

// The acoustic probability for each token in ys.
// Only used for keyword spotting task.
// Used for keyword spotting task.
// For transducer modified beam-search and greedy-search,
// this is filled with log_posterior scores.
std::vector<float> ys_probs;

// lm_probs[i] contains the lm score for each token in ys.
// Used only in transducer modified beam-search.
// Elements filled only if LM is used.
std::vector<float> lm_probs;

// context_scores[i] contains the context-graph score for each token in ys.
// Used only in transducer modified beam-search.
// Elements filled only if `ContextGraph` is used.
std::vector<float> context_scores;

// The total score of ys in log space.
// It contains only acoustic scores
double log_prob = 0;
Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/online-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
r.timestamps.push_back(time);
}

r.ys_probs = std::move(src.ys_probs);
r.lm_probs = std::move(src.lm_probs);
r.context_scores = std::move(src.context_scores);

r.segment = segment;
r.start_time = frames_since_start * frame_shift_ms / 1000.;

Expand Down
82 changes: 38 additions & 44 deletions sherpa-onnx/csrc/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,56 +18,50 @@

namespace sherpa_onnx {

std::string OnlineRecognizerResult::AsJsonString() const {
std::ostringstream os;
os << "{";
os << "\"is_final\":" << (is_final ? "true" : "false") << ", ";
os << "\"segment\":" << segment << ", ";
os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time
<< ", ";

os << "\"text\""
<< ": ";
os << "\"" << text << "\""
<< ", ";

os << "\""
<< "timestamps"
<< "\""
<< ": ";
os << "[";

/// Helper for `OnlineRecognizerResult::AsJsonString()`.
///
/// Renders a vector as a JSON-style list, e.g. "[ 1.00, 2.50 ]".
/// @param vec        items to serialize
/// @param precision  digits after the decimal point (fixed notation)
/// @return the serialized list, e.g. "[ 0.25, 0.50 ]"
template <typename T>
std::string VecToString(const std::vector<T> &vec, int32_t precision = 6) {
  std::ostringstream oss;
  oss << std::fixed << std::setprecision(precision);
  oss << "[ ";
  std::string sep = "";
  for (const auto &item : vec) {
    oss << sep << item;
    sep = ", ";
  }
  oss << " ]";
  return oss.str();
}

/// Helper for `OnlineRecognizerResult::AsJsonString()`.
///
/// Explicit specialization for T = std::string: each item is wrapped in
/// double quotes and the precision argument is ignored.
/// NOTE(review): items are not JSON-escaped; tokens containing `"` or `\`
/// would produce invalid JSON — confirm tokens never contain those.
template <>  // explicit specialization for T = std::string
std::string VecToString<std::string>(const std::vector<std::string> &vec,
                                     int32_t) {  // ignore 2nd arg
  std::ostringstream oss;
  oss << "[ ";
  std::string sep = "";
  for (const auto &item : vec) {
    oss << sep << "\"" << item << "\"";
    sep = ", ";
  }
  oss << " ]";
  return oss.str();
}

std::string OnlineRecognizerResult::AsJsonString() const {
  // Serialize the recognition result as a single-line JSON object.
  // Vector-valued fields are rendered through the VecToString() helpers;
  // timestamps use 2 decimal digits, score vectors use 6.
  std::ostringstream os;
  os << "{ "
     << "\"text\": \"" << text << "\", "
     << "\"tokens\": " << VecToString(tokens) << ", "
     << "\"timestamps\": " << VecToString(timestamps, 2) << ", "
     << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", "
     << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", "
     << "\"context_scores\": " << VecToString(context_scores, 6) << ", "
     << "\"segment\": " << segment << ", "
     << "\"start_time\": " << std::fixed << std::setprecision(2) << start_time
     << ", "
     << "\"is_final\": " << (is_final ? "true" : "false") << "}";
  return os.str();
}

Expand Down
9 changes: 9 additions & 0 deletions sherpa-onnx/csrc/online-recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ struct OnlineRecognizerResult {
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
std::vector<float> timestamps;

std::vector<float> ys_probs; //< log-prob scores from ASR model
std::vector<float> lm_probs; //< log-prob scores from language model
//
/// log-domain scores from "hot-phrase" contextual boosting
std::vector<float> context_scores;

/// ID of this segment
/// When an endpoint is detected, it is incremented
int32_t segment = 0;
Expand All @@ -58,6 +64,9 @@ struct OnlineRecognizerResult {
* "text": "The recognition result",
* "tokens": [x, x, x],
* "timestamps": [x, x, x],
* "ys_probs": [x, x, x],
* "lm_probs": [x, x, x],
* "context_scores": [x, x, x],
* "segment": x,
* "start_time": x,
* "is_final": true|false
Expand Down
8 changes: 8 additions & 0 deletions sherpa-onnx/csrc/online-transducer-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
frame_offset = other.frame_offset;
timestamps = other.timestamps;

ys_probs = other.ys_probs;
lm_probs = other.lm_probs;
context_scores = other.context_scores;

return *this;
}

Expand All @@ -60,6 +64,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
frame_offset = other.frame_offset;
timestamps = std::move(other.timestamps);

ys_probs = std::move(other.ys_probs);
lm_probs = std::move(other.lm_probs);
context_scores = std::move(other.context_scores);

return *this;
}

Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/online-transducer-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ struct OnlineTransducerDecoderResult {
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
std::vector<int32_t> timestamps;

std::vector<float> ys_probs;
std::vector<float> lm_probs;
std::vector<float> context_scores;

// Cache decoder_out for endpointing
Ort::Value decoder_out;

Expand Down
15 changes: 15 additions & 0 deletions sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,11 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
r->tokens = std::vector<int64_t>(start, end);
}


void OnlineTransducerGreedySearchDecoder::Decode(
Ort::Value encoder_out,
std::vector<OnlineTransducerDecoderResult> *result) {

std::vector<int64_t> encoder_out_shape =
encoder_out.GetTensorTypeAndShapeInfo().GetShape();

Expand All @@ -97,6 +99,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
break;
}
}

if (is_batch_decoder_out_cached) {
auto &r = result->front();
std::vector<int64_t> decoder_out_shape =
Expand Down Expand Up @@ -124,6 +127,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
if (blank_penalty_ > 0.0) {
p_logit[0] -= blank_penalty_; // assuming blank id is 0
}

auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
Expand All @@ -138,6 +142,17 @@ void OnlineTransducerGreedySearchDecoder::Decode(
} else {
++r.num_trailing_blanks;
}

// export the per-token log scores
if (y != 0 && y != unk_id_) {
LogSoftmax(p_logit, vocab_size); // renormalize probabilities,
// save time by doing it only for
// emitted symbols
const float *p_logprob = p_logit; // rename p_logit as p_logprob,
// now it contains normalized
// probability
r.ys_probs.push_back(p_logprob[y]);
}
}
if (emitted) {
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
Expand Down
28 changes: 28 additions & 0 deletions sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end());
r->tokens = std::move(tokens);
r->timestamps = std::move(hyp.timestamps);

// export per-token scores
r->ys_probs = std::move(hyp.ys_probs);
r->lm_probs = std::move(hyp.lm_probs);
r->context_scores = std::move(hyp.context_scores);

r->num_trailing_blanks = hyp.num_trailing_blanks;
}

Expand Down Expand Up @@ -180,6 +186,28 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
new_hyp.log_prob = p_logprob[k] + context_score -
prev_lm_log_prob; // log_prob only includes the
// score of the transducer
// export the per-token log scores
if (new_token != 0 && new_token != unk_id_) {
const Hypothesis& prev_i = prev[hyp_index];
// subtract 'prev[i]' path scores, which were added before
// getting topk tokens
float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
new_hyp.ys_probs.push_back(y_prob);

if (lm_) { // export only when LM is used
float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
if (lm_scale_ != 0.0) {
lm_prob /= lm_scale_; // remove lm-scale
}
new_hyp.lm_probs.push_back(lm_prob);
}

// export only when `ContextGraph` is used
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
new_hyp.context_scores.push_back(context_score);
}
}

hyps.Add(std::move(new_hyp));
} // for (auto k : topk)
cur.push_back(std::move(hyps));
Expand Down
21 changes: 20 additions & 1 deletion sherpa-onnx/python/csrc/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,26 @@ static void PybindOnlineRecognizerResult(py::module *m) {
[](PyClass &self) -> float { return self.start_time; })
.def_property_readonly(
"timestamps",
[](PyClass &self) -> std::vector<float> { return self.timestamps; });
[](PyClass &self) -> std::vector<float> { return self.timestamps; })
.def_property_readonly(
"ys_probs",
[](PyClass &self) -> std::vector<float> { return self.ys_probs; })
.def_property_readonly(
"lm_probs",
[](PyClass &self) -> std::vector<float> { return self.lm_probs; })
.def_property_readonly(
"context_scores",
[](PyClass &self) -> std::vector<float> {
return self.context_scores;
})
.def_property_readonly(
"segment",
[](PyClass &self) -> int32_t { return self.segment; })
.def_property_readonly(
"is_final",
[](PyClass &self) -> bool { return self.is_final; })
.def("as_json_string", &PyClass::AsJsonString,
py::call_guard<py::gil_scoped_release>());
}

static void PybindOnlineRecognizerConfig(py::module *m) {
Expand Down
12 changes: 12 additions & 0 deletions sherpa-onnx/python/sherpa_onnx/online_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,9 @@ def is_ready(self, s: OnlineStream) -> bool:
def get_result(self, s: OnlineStream) -> str:
    """Return the current recognition text of stream ``s``.

    Leading and trailing whitespace is stripped from the text.
    """
    result = self.recognizer.get_result(s)
    return result.text.strip()

def get_result_as_json_string(self, s: OnlineStream) -> str:
    """Return the full recognition result of stream ``s`` serialized as a
    JSON string (text, tokens, timestamps, per-token scores, ...)."""
    result = self.recognizer.get_result(s)
    return result.as_json_string()

def tokens(self, s: OnlineStream) -> List[str]:
    """Return the decoded tokens of the current result of stream ``s``."""
    result = self.recognizer.get_result(s)
    return result.tokens

Expand All @@ -512,6 +515,15 @@ def timestamps(self, s: OnlineStream) -> List[float]:
def start_time(self, s: OnlineStream) -> float:
    """Return the start time (in seconds) of the current result of ``s``."""
    result = self.recognizer.get_result(s)
    return result.start_time

def ys_probs(self, s: OnlineStream) -> List[float]:
    """Return the per-token acoustic log-probability scores of ``s``
    (log-prob scores from the ASR model)."""
    result = self.recognizer.get_result(s)
    return result.ys_probs

def lm_probs(self, s: OnlineStream) -> List[float]:
    """Return the per-token language-model log-probability scores of ``s``
    (empty unless an LM is used during decoding)."""
    result = self.recognizer.get_result(s)
    return result.lm_probs

def context_scores(self, s: OnlineStream) -> List[float]:
    """Return the per-token context-graph ("hot-phrase" boosting) scores
    of ``s`` (empty unless a context graph is used)."""
    result = self.recognizer.get_result(s)
    return result.context_scores

def is_endpoint(self, s: OnlineStream) -> bool:
    """Return True if an endpoint has been detected on stream ``s``."""
    detected = self.recognizer.is_endpoint(s)
    return detected

Expand Down
Loading