Skip to content

Commit

Permalink
refactor: Update whisper model path handling in transcription filter
Browse files: browse the repository at this point in the history
  • Loading branch information
royshil committed Jun 21, 2024
1 parent d64ec2a commit ccd0c61
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 23 deletions.
4 changes: 3 additions & 1 deletion src/model-utils/model-downloader-ui.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,9 @@ ModelDownloader::~ModelDownloader()
}
delete this->download_thread;
}
delete this->download_worker;
if (this->download_worker != nullptr) {
delete this->download_worker;
}
}

ModelDownloadWorker::~ModelDownloadWorker()
Expand Down
5 changes: 1 addition & 4 deletions src/transcription-filter-callbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,10 +300,7 @@ void enable_callback(void *data_, calldata_t *cd)
obs_log(gf_->log_level, "enable_callback: enable");
gf_->active = true;
reset_caption_state(gf_);
// get filter settings from gf_->context
obs_data_t *settings = obs_source_get_settings(gf_->context);
update_whisper_model(gf_, settings);
obs_data_release(settings);
update_whisper_model(gf_);
} else {
obs_log(gf_->log_level, "enable_callback: disable");
gf_->active = false;
Expand Down
16 changes: 8 additions & 8 deletions src/transcription-filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ void transcription_filter_update(void *data, obs_data_t *s)
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);

gf->log_level = LOG_INFO; //(int)obs_data_get_int(s, "log_level");
gf->log_level = (int)obs_data_get_int(s, "log_level");
gf->vad_enabled = obs_data_get_bool(s, "vad_enabled");
gf->log_words = obs_data_get_bool(s, "log_words");
gf->caption_to_stream = obs_data_get_bool(s, "caption_to_stream");
Expand Down Expand Up @@ -216,9 +216,9 @@ void transcription_filter_update(void *data, obs_data_t *s)
}
}
} else {
obs_log(LOG_INFO, "buffered_output disable");
obs_log(gf->log_level, "buffered_output disable");
if (gf->buffered_output) {
obs_log(LOG_INFO, "buffered_output currently enabled, disabling");
obs_log(gf->log_level, "buffered_output currently enabled, disabling");
if (gf->captions_monitor.isEnabled()) {
gf->captions_monitor.clear();
gf->captions_monitor.stopThread();
Expand Down Expand Up @@ -345,11 +345,11 @@ void transcription_filter_update(void *data, obs_data_t *s)
}
}

if (gf->initial_creation && obs_source_enabled(gf->context)) {
if (gf->initial_creation && gf->context != nullptr && obs_source_enabled(gf->context)) {
obs_log(LOG_INFO, "Initial filter creation and source enabled");

// source was enabled on creation
obs_data_t *settings = obs_source_get_settings(gf->context);
update_whisper_model(gf, settings);
obs_data_release(settings);
update_whisper_model(gf);
gf->active = true;
gf->initial_creation = false;
}
Expand Down Expand Up @@ -497,7 +497,7 @@ void transcription_filter_hide(void *data)

void transcription_filter_defaults(obs_data_t *s)
{
obs_log(LOG_INFO, "filter defaults");
obs_log(LOG_DEBUG, "filter defaults");

obs_data_set_default_bool(s, "buffered_output", false);
obs_data_set_default_int(s, "buffer_num_lines", 2);
Expand Down
28 changes: 21 additions & 7 deletions src/whisper-utils/whisper-model-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,28 @@
#include "plugin-support.h"
#include "model-utils/model-downloader.h"

void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
void update_whisper_model(struct transcription_filter_data *gf)
{
// update the whisper model path
if (gf->context == nullptr) {
obs_log(LOG_ERROR, "obs_source_t context is null");
return;
}

obs_data_t *s = obs_source_get_settings(gf->context);
if (s == nullptr) {
obs_log(LOG_ERROR, "obs_data_t settings is null");
return;
}

// Get settings from context
std::string new_model_path = obs_data_get_string(s, "whisper_model_path");
std::string external_model_file_path =
obs_data_get_string(s, "whisper_model_path_external");
const bool new_dtw_timestamps = obs_data_get_bool(s, "dtw_token_timestamps");
obs_data_release(s);

// update the whisper model path

const bool is_external_model = new_model_path.find("!!!external!!!") != std::string::npos;

char *silero_vad_model_file = obs_module_file("models/silero-vad/silero_vad.onnx");
Expand Down Expand Up @@ -73,8 +91,6 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
}
} else {
// new model is external file, get file location from file property
std::string external_model_file_path =
obs_data_get_string(s, "whisper_model_path_external");
if (external_model_file_path.empty()) {
obs_log(LOG_WARNING, "External model file path is empty");
} else {
Expand All @@ -98,13 +114,11 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
gf->whisper_model_path.c_str(), new_model_path.c_str());
}

const bool new_dtw_timestamps = obs_data_get_bool(s, "dtw_token_timestamps");

if (new_dtw_timestamps != gf->enable_token_ts_dtw) {
// dtw_token_timestamps changed
obs_log(gf->log_level, "dtw_token_timestamps changed from %d to %d",
gf->enable_token_ts_dtw, new_dtw_timestamps);
gf->enable_token_ts_dtw = obs_data_get_bool(s, "dtw_token_timestamps");
gf->enable_token_ts_dtw = new_dtw_timestamps;
shutdown_whisper_thread(gf);
start_whisper_thread_with_path(gf, gf->whisper_model_path,
silero_vad_model_file_str.c_str());
Expand Down
2 changes: 1 addition & 1 deletion src/whisper-utils/whisper-model-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@

#include "transcription-filter-data.h"

void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s);
void update_whisper_model(struct transcription_filter_data *gf);

#endif // WHISPER_MODEL_UTILS_H
4 changes: 2 additions & 2 deletions src/whisper-utils/whisper-processing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ void whisper_loop(void *data)
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);

obs_log(LOG_INFO, "starting whisper thread");
obs_log(gf->log_level, "Starting whisper thread");

vad_state current_vad_state = {false, 0, 0};
// 500 ms worth of audio is needed for VAD segmentation
Expand Down Expand Up @@ -511,5 +511,5 @@ void whisper_loop(void *data)
gf->wshiper_thread_cv.wait_for(lock, std::chrono::milliseconds(50));
}

obs_log(LOG_INFO, "exiting whisper thread");
obs_log(gf->log_level, "Exiting whisper thread");
}

0 comments on commit ccd0c61

Please sign in to comment.