Skip to content

Commit

Permalink
Only reset VAD state between chunks of activity
Browse files Browse the repository at this point in the history
  • Loading branch information
palana committed Aug 2, 2024
1 parent 1ed62ca commit 606938d
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
14 changes: 8 additions & 6 deletions src/whisper-utils/silero-vad-onnx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,13 @@ void VadIterator::init_onnx_model(const SileroString &model_path)
session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
};

void VadIterator::reset_states()
void VadIterator::reset_states(bool reset_state)
{
// Call reset before each audio start
std::memset(_state.data(), 0.0f, _state.size() * sizeof(float));
triggered = false;
if (reset_state) {
// Call reset before each audio start
std::memset(_state.data(), 0.0f, _state.size() * sizeof(float));
triggered = false;
}
temp_end = 0;
current_sample = 0;

Expand Down Expand Up @@ -257,9 +259,9 @@ void VadIterator::predict(const std::vector<float> &data)
}
};

void VadIterator::process(const std::vector<float> &input_wav)
void VadIterator::process(const std::vector<float> &input_wav, bool reset_state)
{
reset_states();
reset_states(reset_state);

audio_length_samples = (int)input_wav.size();

Expand Down
4 changes: 2 additions & 2 deletions src/whisper-utils/silero-vad-onnx.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ class VadIterator {
private:
void init_engine_threads(int inter_threads, int intra_threads);
void init_onnx_model(const SileroString &model_path);
void reset_states();
void reset_states(bool reset_state);
float predict_one(const std::vector<float> &data);
void predict(const std::vector<float> &data);

public:
void process(const std::vector<float> &input_wav);
void process(const std::vector<float> &input_wav, bool reset_state = true);
void process(const std::vector<float> &input_wav, std::vector<float> &output_wav);
void collect_chunks(const std::vector<float> &input_wav, std::vector<float> &output_wav);
const std::vector<timestamp_t> get_speech_timestamps() const;
Expand Down
2 changes: 1 addition & 1 deletion src/whisper-utils/whisper-processing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v
obs_log(gf->log_level, "sending %d frames to vad", vad_input.size());
{
ProfileScope("vad->process");
gf->vad->process(vad_input);
gf->vad->process(vad_input, !last_vad_state.vad_on);
}

const uint64_t start_ts_offset_ms = start_timestamp_offset_ns / 1000000;
Expand Down

0 comments on commit 606938d

Please sign in to comment.