Skip to content

Commit

Permalink
Start and stop based on filter enable status (#111)
Browse files Browse the repository at this point in the history
* refactor: Add initial_creation flag to transcription filter data

* refactor: Improve caption duration calculation in set_text_callback
  • Loading branch information
royshil authored Jun 11, 2024
1 parent 91c2842 commit 2aa151e
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 58 deletions.
32 changes: 29 additions & 3 deletions src/transcription-filter-callbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
#include "transcription-utils.h"
#include "translation/translation.h"
#include "translation/translation-includes.h"

#define SEND_TIMED_METADATA_URL "http://localhost:8080/timed-metadata"
#include "whisper-utils/whisper-utils.h"
#include "whisper-utils/whisper-model-utils.h"

void send_caption_to_source(const std::string &target_source_name, const std::string &caption,
struct transcription_filter_data *gf)
Expand Down Expand Up @@ -130,7 +130,13 @@ void set_text_callback(struct transcription_filter_data *gf,
if (gf->caption_to_stream) {
obs_output_t *streaming_output = obs_frontend_get_streaming_output();
if (streaming_output) {
obs_output_output_caption_text1(streaming_output, str_copy.c_str());
// calculate the duration in seconds
const uint64_t duration =
result.end_timestamp_ms - result.start_timestamp_ms;
obs_log(gf->log_level, "Sending caption to streaming output: %s",
str_copy.c_str());
obs_output_output_caption_text2(streaming_output, str_copy.c_str(),
(double)duration / 1000.0);
obs_output_release(streaming_output);
}
}
Expand Down Expand Up @@ -285,3 +291,23 @@ void media_stopped_callback(void *data_, calldata_t *cd)
gf_->active = false;
reset_caption_state(gf_);
}

void enable_callback(void *data_, calldata_t *cd)
{
transcription_filter_data *gf_ = static_cast<struct transcription_filter_data *>(data_);
bool enable = calldata_bool(cd, "enabled");
if (enable) {
obs_log(gf_->log_level, "enable_callback: enable");
gf_->active = true;
reset_caption_state(gf_);
// get filter settings from gf_->context
obs_data_t *settings = obs_source_get_settings(gf_->context);
update_whisper_model(gf_, settings);
obs_data_release(settings);
} else {
obs_log(gf_->log_level, "enable_callback: disable");
gf_->active = false;
reset_caption_state(gf_);
shutdown_whisper_thread(gf_);
}
}
1 change: 1 addition & 0 deletions src/transcription-filter-callbacks.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@ void media_started_callback(void *data_, calldata_t *cd);
void media_pause_callback(void *data_, calldata_t *cd);
void media_restart_callback(void *data_, calldata_t *cd);
void media_stopped_callback(void *data_, calldata_t *cd);
void enable_callback(void *data_, calldata_t *cd);

#endif /* TRANSCRIPTION_FILTER_CALLBACKS_H */
1 change: 1 addition & 0 deletions src/transcription-filter-data.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ struct transcription_filter_data {
bool fix_utf8 = true;
bool enable_audio_chunks_callback = false;
bool source_signals_set = false;
bool initial_creation = true;

// Last transcription result
std::string last_text;
Expand Down
2 changes: 2 additions & 0 deletions src/transcription-filter.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,6 @@ struct obs_source_info transcription_filter_info = {
.deactivate = transcription_filter_deactivate,
.filter_audio = transcription_filter_filter_audio,
.filter_remove = transcription_filter_remove,
.show = transcription_filter_show,
.hide = transcription_filter_hide,
};
119 changes: 74 additions & 45 deletions src/transcription-filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_
if (!audio) {
return nullptr;
}

if (data == nullptr) {
return audio;
}
Expand Down Expand Up @@ -137,6 +138,9 @@ void transcription_filter_destroy(void *data)
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);

signal_handler_t *sh_filter = obs_source_get_signal_handler(gf->context);
signal_handler_disconnect(sh_filter, "enable", enable_callback, gf);

obs_log(gf->log_level, "filter destroy");
shutdown_whisper_thread(gf);

Expand Down Expand Up @@ -167,7 +171,7 @@ void transcription_filter_update(void *data, obs_data_t *s)
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);

gf->log_level = (int)obs_data_get_int(s, "log_level");
gf->log_level = LOG_INFO; //(int)obs_data_get_int(s, "log_level");
gf->vad_enabled = obs_data_get_bool(s, "vad_enabled");
gf->log_words = obs_data_get_bool(s, "log_words");
gf->caption_to_stream = obs_data_get_bool(s, "caption_to_stream");
Expand Down Expand Up @@ -293,51 +297,61 @@ void transcription_filter_update(void *data, obs_data_t *s)
gf->text_source_name = new_text_source_name;
}

obs_log(gf->log_level, "update whisper model");
update_whisper_model(gf, s);

obs_log(gf->log_level, "update whisper params");
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
{
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);

gf->sentence_psum_accept_thresh =
(float)obs_data_get_double(s, "sentence_psum_accept_thresh");
gf->sentence_psum_accept_thresh =
(float)obs_data_get_double(s, "sentence_psum_accept_thresh");

gf->whisper_params = whisper_full_default_params(
(whisper_sampling_strategy)obs_data_get_int(s, "whisper_sampling_method"));
gf->whisper_params.duration_ms = (int)obs_data_get_int(s, "buffer_size_msec");
if (!new_translate || gf->translation_model_index != "whisper-based-translation") {
gf->whisper_params.language = obs_data_get_string(s, "whisper_language_select");
} else {
// take the language from gf->target_lang
gf->whisper_params.language = language_codes_2_reverse[gf->target_lang].c_str();
gf->whisper_params = whisper_full_default_params(
(whisper_sampling_strategy)obs_data_get_int(s, "whisper_sampling_method"));
gf->whisper_params.duration_ms = (int)obs_data_get_int(s, "buffer_size_msec");
if (!new_translate || gf->translation_model_index != "whisper-based-translation") {
gf->whisper_params.language =
obs_data_get_string(s, "whisper_language_select");
} else {
// take the language from gf->target_lang
gf->whisper_params.language =
language_codes_2_reverse[gf->target_lang].c_str();
}
gf->whisper_params.initial_prompt = obs_data_get_string(s, "initial_prompt");
gf->whisper_params.n_threads = (int)obs_data_get_int(s, "n_threads");
gf->whisper_params.n_max_text_ctx = (int)obs_data_get_int(s, "n_max_text_ctx");
gf->whisper_params.translate = obs_data_get_bool(s, "whisper_translate");
gf->whisper_params.no_context = obs_data_get_bool(s, "no_context");
gf->whisper_params.single_segment = obs_data_get_bool(s, "single_segment");
gf->whisper_params.print_special = obs_data_get_bool(s, "print_special");
gf->whisper_params.print_progress = obs_data_get_bool(s, "print_progress");
gf->whisper_params.print_realtime = obs_data_get_bool(s, "print_realtime");
gf->whisper_params.print_timestamps = obs_data_get_bool(s, "print_timestamps");
gf->whisper_params.token_timestamps = obs_data_get_bool(s, "token_timestamps");
gf->whisper_params.thold_pt = (float)obs_data_get_double(s, "thold_pt");
gf->whisper_params.thold_ptsum = (float)obs_data_get_double(s, "thold_ptsum");
gf->whisper_params.max_len = (int)obs_data_get_int(s, "max_len");
gf->whisper_params.split_on_word = obs_data_get_bool(s, "split_on_word");
gf->whisper_params.max_tokens = (int)obs_data_get_int(s, "max_tokens");
gf->whisper_params.speed_up = obs_data_get_bool(s, "speed_up");
gf->whisper_params.suppress_blank = obs_data_get_bool(s, "suppress_blank");
gf->whisper_params.suppress_non_speech_tokens =
obs_data_get_bool(s, "suppress_non_speech_tokens");
gf->whisper_params.temperature = (float)obs_data_get_double(s, "temperature");
gf->whisper_params.max_initial_ts = (float)obs_data_get_double(s, "max_initial_ts");
gf->whisper_params.length_penalty = (float)obs_data_get_double(s, "length_penalty");

if (gf->vad_enabled && gf->vad) {
const float vad_threshold = (float)obs_data_get_double(s, "vad_threshold");
gf->vad->set_threshold(vad_threshold);
}
}
gf->whisper_params.initial_prompt = obs_data_get_string(s, "initial_prompt");
gf->whisper_params.n_threads = (int)obs_data_get_int(s, "n_threads");
gf->whisper_params.n_max_text_ctx = (int)obs_data_get_int(s, "n_max_text_ctx");
gf->whisper_params.translate = obs_data_get_bool(s, "whisper_translate");
gf->whisper_params.no_context = obs_data_get_bool(s, "no_context");
gf->whisper_params.single_segment = obs_data_get_bool(s, "single_segment");
gf->whisper_params.print_special = obs_data_get_bool(s, "print_special");
gf->whisper_params.print_progress = obs_data_get_bool(s, "print_progress");
gf->whisper_params.print_realtime = obs_data_get_bool(s, "print_realtime");
gf->whisper_params.print_timestamps = obs_data_get_bool(s, "print_timestamps");
gf->whisper_params.token_timestamps = obs_data_get_bool(s, "token_timestamps");
gf->whisper_params.thold_pt = (float)obs_data_get_double(s, "thold_pt");
gf->whisper_params.thold_ptsum = (float)obs_data_get_double(s, "thold_ptsum");
gf->whisper_params.max_len = (int)obs_data_get_int(s, "max_len");
gf->whisper_params.split_on_word = obs_data_get_bool(s, "split_on_word");
gf->whisper_params.max_tokens = (int)obs_data_get_int(s, "max_tokens");
gf->whisper_params.speed_up = obs_data_get_bool(s, "speed_up");
gf->whisper_params.suppress_blank = obs_data_get_bool(s, "suppress_blank");
gf->whisper_params.suppress_non_speech_tokens =
obs_data_get_bool(s, "suppress_non_speech_tokens");
gf->whisper_params.temperature = (float)obs_data_get_double(s, "temperature");
gf->whisper_params.max_initial_ts = (float)obs_data_get_double(s, "max_initial_ts");
gf->whisper_params.length_penalty = (float)obs_data_get_double(s, "length_penalty");

if (gf->vad_enabled && gf->vad) {
const float vad_threshold = (float)obs_data_get_double(s, "vad_threshold");
gf->vad->set_threshold(vad_threshold);

if (gf->initial_creation && obs_source_enabled(gf->context)) {
// source was enabled on creation
obs_data_t *settings = obs_source_get_settings(gf->context);
update_whisper_model(gf, settings);
obs_data_release(settings);
gf->active = true;
gf->initial_creation = false;
}
}

Expand Down Expand Up @@ -421,12 +435,13 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
gf->whisper_model_path = std::string(""); // The update function will set the model path
gf->whisper_context = nullptr;

signal_handler_t *sh_filter = obs_source_get_signal_handler(gf->context);
signal_handler_connect(sh_filter, "enable", enable_callback, gf);

obs_log(gf->log_level, "run update");
// get the settings updated on the filter data struct
transcription_filter_update(gf, settings);

gf->active = true;

// handle the event OBS_FRONTEND_EVENT_RECORDING_STARTING to reset the srt sentence number
// to match the subtitles with the recording
obs_frontend_add_event_callback(recording_state_callback, gf);
Expand Down Expand Up @@ -466,6 +481,20 @@ void transcription_filter_deactivate(void *data)
gf->active = false;
}

void transcription_filter_show(void *data)
{
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);
obs_log(gf->log_level, "filter show");
}

void transcription_filter_hide(void *data)
{
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);
obs_log(gf->log_level, "filter hide");
}

void transcription_filter_defaults(obs_data_t *s)
{
obs_log(LOG_INFO, "filter defaults");
Expand Down Expand Up @@ -586,11 +615,11 @@ obs_properties_t *transcription_filter_properties(void *data)
whisper_model_path_external,
[](void *data_, obs_properties_t *props, obs_property_t *property,
obs_data_t *settings) {
obs_log(LOG_INFO, "whisper_model_path_external modified");
UNUSED_PARAMETER(property);
UNUSED_PARAMETER(props);
struct transcription_filter_data *gf_ =
static_cast<struct transcription_filter_data *>(data_);
obs_log(gf_->log_level, "whisper_model_path_external modified");
transcription_filter_update(gf_, settings);
return true;
},
Expand Down
2 changes: 2 additions & 0 deletions src/transcription-filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ void transcription_filter_deactivate(void *data);
void transcription_filter_defaults(obs_data_t *s);
obs_properties_t *transcription_filter_properties(void *data);
void transcription_filter_remove(void *data, obs_source_t *source);
void transcription_filter_show(void *data);
void transcription_filter_hide(void *data);

const char *const PLUGIN_INFO_TEMPLATE =
"<a href=\"https://github.com/occ-ai/obs-localvocal/\">LocalVocal</a> (%1) by "
Expand Down
18 changes: 12 additions & 6 deletions src/whisper-utils/whisper-model-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
obs_log(LOG_ERROR, "Cannot find Silero VAD model file");
return;
}
obs_log(gf->log_level, "Silero VAD model file: %s", silero_vad_model_file);
std::string silero_vad_model_file_str = std::string(silero_vad_model_file);
bfree(silero_vad_model_file);

if (gf->whisper_model_path.empty() || gf->whisper_model_path != new_model_path ||
is_external_model) {
Expand Down Expand Up @@ -49,14 +52,15 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
obs_log(LOG_WARNING, "Whisper model does not exist");
download_model_with_ui_dialog(
model_info,
[gf, new_model_path, silero_vad_model_file](
[gf, new_model_path, silero_vad_model_file_str](
int download_status, const std::string &path) {
if (download_status == 0) {
obs_log(LOG_INFO,
"Model download complete");
gf->whisper_model_path = new_model_path;
start_whisper_thread_with_path(
gf, path, silero_vad_model_file);
gf, path,
silero_vad_model_file_str.c_str());
} else {
obs_log(LOG_ERROR, "Model download failed");
}
Expand All @@ -65,7 +69,7 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
// Model exists, just load it
gf->whisper_model_path = new_model_path;
start_whisper_thread_with_path(gf, model_file_found,
silero_vad_model_file);
silero_vad_model_file_str.c_str());
}
} else {
// new model is external file, get file location from file property
Expand All @@ -82,8 +86,9 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
} else {
shutdown_whisper_thread(gf);
gf->whisper_model_path = new_model_path;
start_whisper_thread_with_path(gf, external_model_file_path,
silero_vad_model_file);
start_whisper_thread_with_path(
gf, external_model_file_path,
silero_vad_model_file_str.c_str());
}
}
}
Expand All @@ -101,6 +106,7 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
gf->enable_token_ts_dtw, new_dtw_timestamps);
gf->enable_token_ts_dtw = obs_data_get_bool(s, "dtw_token_timestamps");
shutdown_whisper_thread(gf);
start_whisper_thread_with_path(gf, gf->whisper_model_path, silero_vad_model_file);
start_whisper_thread_with_path(gf, gf->whisper_model_path,
silero_vad_model_file_str.c_str());
}
}
19 changes: 15 additions & 4 deletions src/whisper-utils/whisper-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

#include <obs-module.h>

#ifdef _WIN32
#include <Windows.h>
#endif

void shutdown_whisper_thread(struct transcription_filter_data *gf)
{
obs_log(gf->log_level, "shutdown_whisper_thread");
Expand All @@ -27,7 +31,8 @@ void start_whisper_thread_with_path(struct transcription_filter_data *gf,
const std::string &whisper_model_path,
const char *silero_vad_model_file)
{
obs_log(gf->log_level, "start_whisper_thread_with_path: %s", whisper_model_path.c_str());
obs_log(gf->log_level, "start_whisper_thread_with_path: %s, silero model path: %s",
whisper_model_path.c_str(), silero_vad_model_file);
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
if (gf->whisper_context != nullptr) {
obs_log(LOG_ERROR, "cannot init whisper: whisper_context is not null");
Expand All @@ -36,16 +41,22 @@ void start_whisper_thread_with_path(struct transcription_filter_data *gf,

// initialize Silero VAD
#ifdef _WIN32
std::wstring silero_vad_model_path;
silero_vad_model_path.assign(silero_vad_model_file,
silero_vad_model_file + strlen(silero_vad_model_file));
// convert mbstring to wstring
int count = MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file,
strlen(silero_vad_model_file), NULL, 0);
std::wstring silero_vad_model_path(count, 0);
MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, strlen(silero_vad_model_file),
&silero_vad_model_path[0], count);
obs_log(gf->log_level, "Create silero VAD: %S", silero_vad_model_path.c_str());
#else
std::string silero_vad_model_path = silero_vad_model_file;
obs_log(gf->log_level, "Create silero VAD: %s", silero_vad_model_path.c_str());
#endif
// roughly following https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py
// for silero vad parameters
gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE));

obs_log(gf->log_level, "Create whisper context");
gf->whisper_context = init_whisper_context(whisper_model_path, gf);
if (gf->whisper_context == nullptr) {
obs_log(LOG_ERROR, "Failed to initialize whisper context");
Expand Down

0 comments on commit 2aa151e

Please sign in to comment.