Skip to content

Commit

Permalink
extend per-token scores to greedy-search
Browse files Browse the repository at this point in the history
  • Loading branch information
KarelVesely84 committed Feb 14, 2024
1 parent b14610a commit 5190e9b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
16 changes: 16 additions & 0 deletions sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,11 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
r->tokens = std::vector<int64_t>(start, end);
}


void OnlineTransducerGreedySearchDecoder::Decode(
Ort::Value encoder_out,
std::vector<OnlineTransducerDecoderResult> *result) {

std::vector<int64_t> encoder_out_shape =
encoder_out.GetTensorTypeAndShapeInfo().GetShape();

Expand All @@ -97,6 +99,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
break;
}
}

if (is_batch_decoder_out_cached) {
auto &r = result->front();
std::vector<int64_t> decoder_out_shape =
Expand Down Expand Up @@ -124,6 +127,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
if (blank_penalty_ > 0.0) {
p_logit[0] -= blank_penalty_; // assuming blank id is 0
}

auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
Expand All @@ -138,6 +142,18 @@ void OnlineTransducerGreedySearchDecoder::Decode(
} else {
++r.num_trailing_blanks;
}

// export the per-token log scores
if (y != 0 && y != unk_id_) {
LogSoftmax(p_logit, vocab_size); // renormalize probabilities,
// save time by doing it only for
// emitted symbols
float *p_logprob = p_logit; // rename p_logit as p_logprob,
// now it contains normalized
// probability
r.ys_probs.push_back(p_logprob[y]);
}

}
if (emitted) {
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,6 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
}
p_logprob = p_logit; // we changed p_logprob in the above for loop

// KarelVesely: Should the context score be added already before taking topk tokens?

for (int32_t b = 0; b != batch_size; ++b) {
int32_t frame_offset = (*result)[b].frame_offset;
int32_t start = hyps_row_splits[b];
Expand Down Expand Up @@ -190,7 +188,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
prev_lm_log_prob; // log_prob only includes the
// score of the transducer
// export the per-token log scores
{
if (new_token != 0 && new_token != unk_id_) {
const Hypothesis& prev_i = prev[hyp_index];
// subtract 'prev[i]' path scores, which were added before
// for getting topk tokens
Expand Down

0 comments on commit 5190e9b

Please sign in to comment.