Skip to content

Commit

Permalink
refactor: Update create_obs_text_source function to create source onl…
Browse files Browse the repository at this point in the history
…y if it doesn't exist (#126)
  • Loading branch information
royshil authored Jul 9, 2024
1 parent ee07bbe commit 234a938
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 27 deletions.
16 changes: 11 additions & 5 deletions src/transcription-filter-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,23 @@
#include <obs.h>
#include <obs-frontend-api.h>

void create_obs_text_source()
void create_obs_text_source_if_needed()
{
// check if a source called "LocalVocal Subtitles" exists
obs_source_t *source = obs_get_source_by_name("LocalVocal Subtitles");
if (source) {
// source already exists, release it
obs_source_release(source);
return;
}

// create a new OBS text source called "LocalVocal Subtitles"
obs_source_t *scene_as_source = obs_frontend_get_current_scene();
obs_scene_t *scene = obs_scene_from_source(scene_as_source);
#ifdef _WIN32
obs_source_t *source =
obs_source_create("text_gdiplus_v2", "LocalVocal Subtitles", nullptr, nullptr);
source = obs_source_create("text_gdiplus_v2", "LocalVocal Subtitles", nullptr, nullptr);
#else
obs_source_t *source =
obs_source_create("text_ft2_source_v2", "LocalVocal Subtitles", nullptr, nullptr);
source = obs_source_create("text_ft2_source_v2", "LocalVocal Subtitles", nullptr, nullptr);
#endif
if (source) {
// add source to the current scene
Expand Down
2 changes: 1 addition & 1 deletion src/transcription-filter-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ inline enum speaker_layout convert_speaker_layout(uint8_t channels)
}
}

void create_obs_text_source();
void create_obs_text_source_if_needed();

bool add_sources_to_list(void *list_property, obs_source_t *source);

Expand Down
75 changes: 56 additions & 19 deletions src/transcription-filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,15 @@ void transcription_filter_update(void *data, obs_data_t *s)
int new_buffer_num_chars_per_line = (int)obs_data_get_int(s, "buffer_num_chars_per_line");
TokenBufferSegmentation new_buffer_output_type =
(TokenBufferSegmentation)obs_data_get_int(s, "buffer_output_type");
gf->filter_words_replace =
deserialize_filter_words_replace(obs_data_get_string(s, "filter_words_replace"));
const char *filter_words_replace = obs_data_get_string(s, "filter_words_replace");
if (filter_words_replace != nullptr && strlen(filter_words_replace) > 0) {
obs_log(gf->log_level, "filter_words_replace: %s", filter_words_replace);
// deserialize the filter words replace
gf->filter_words_replace = deserialize_filter_words_replace(filter_words_replace);
} else {
// clear the filter words replace
gf->filter_words_replace.clear();
}

if (gf->save_to_file) {
gf->output_file_path = "";
Expand Down Expand Up @@ -261,10 +268,23 @@ void transcription_filter_update(void *data, obs_data_t *s)
gf->translation_ctx.add_context = obs_data_get_bool(s, "translate_add_context");
gf->translation_ctx.input_tokenization_style =
(InputTokenizationStyle)obs_data_get_int(s, "translate_input_tokenization_style");
gf->translation_output = obs_data_get_string(s, "translate_output");
std::string new_translate_model_index = obs_data_get_string(s, "translate_model");
const char *translate_output_cstr = obs_data_get_string(s, "translate_output");
gf->translation_output =
(translate_output_cstr != nullptr && strlen(translate_output_cstr) > 0)
? translate_output_cstr
: "";
const char *translate_model_path_cstr = obs_data_get_string(s, "translate_model_path");
std::string new_translate_model_index =
(translate_model_path_cstr != nullptr && strlen(translate_model_path_cstr) > 0)
? translate_model_path_cstr
: "";
const char *translate_model_path_external_cstr =
obs_data_get_string(s, "translate_model_path_external");
std::string new_translation_model_path_external =
obs_data_get_string(s, "translation_model_path_external");
(translate_model_path_external_cstr != nullptr &&
strlen(translate_model_path_external_cstr) > 0)
? translate_model_path_external_cstr
: "";

if (new_translate != gf->translate ||
new_translate_model_index != gf->translation_model_index ||
Expand Down Expand Up @@ -325,8 +345,12 @@ void transcription_filter_update(void *data, obs_data_t *s)
(whisper_sampling_strategy)obs_data_get_int(s, "whisper_sampling_method"));
gf->whisper_params.duration_ms = (int)obs_data_get_int(s, "buffer_size_msec");
if (!new_translate || gf->translation_model_index != "whisper-based-translation") {
gf->whisper_params.language =
const char *whisper_language_select =
obs_data_get_string(s, "whisper_language_select");
gf->whisper_params.language = (whisper_language_select != nullptr &&
strlen(whisper_language_select) > 0)
? whisper_language_select
: "auto";
} else {
// take the language from gf->target_lang
if (language_codes_2_reverse.count(gf->target_lang) > 0) {
Expand All @@ -336,7 +360,10 @@ void transcription_filter_update(void *data, obs_data_t *s)
gf->whisper_params.language = "auto";
}
}
gf->whisper_params.initial_prompt = obs_data_get_string(s, "initial_prompt");
gf->whisper_params.initial_prompt =
obs_data_get_string(s, "initial_prompt") != nullptr
? obs_data_get_string(s, "initial_prompt")
: "";
gf->whisper_params.n_threads = (int)obs_data_get_int(s, "n_threads");
gf->whisper_params.n_max_text_ctx = (int)obs_data_get_int(s, "n_max_text_ctx");
gf->whisper_params.translate = obs_data_get_bool(s, "whisper_translate");
Expand Down Expand Up @@ -377,7 +404,9 @@ void transcription_filter_update(void *data, obs_data_t *s)
} else {
// check if the whisper model selection has changed
const std::string new_model_path =
obs_data_get_string(s, "whisper_model_path");
obs_data_get_string(s, "whisper_model_path") != nullptr
? obs_data_get_string(s, "whisper_model_path")
: "Whisper Tiny English (74Mb)";
if (gf->whisper_model_path != new_model_path) {
obs_log(LOG_INFO, "New model selected: %s", new_model_path.c_str());
update_whisper_model(gf);
Expand Down Expand Up @@ -418,6 +447,11 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
// allocate copy buffers
gf->copy_buffers[0] =
static_cast<float *>(bzalloc(gf->channels * gf->frames * sizeof(float)));
if (gf->copy_buffers[0] == nullptr) {
obs_log(LOG_ERROR, "Failed to allocate copy buffer");
gf->active = false;
return nullptr;
}
for (size_t c = 1; c < gf->channels; c++) { // set the channel pointers
gf->copy_buffers[c] = gf->copy_buffers[0] + c * gf->frames;
}
Expand All @@ -439,21 +473,18 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
dst.speakers = convert_speaker_layout((uint8_t)1);

gf->resampler_to_whisper = audio_resampler_create(&dst, &src);
if (!gf->resampler_to_whisper) {
obs_log(LOG_ERROR, "Failed to create resampler");
gf->active = false;
return nullptr;
}

obs_log(gf->log_level, "clear text source data");
const char *subtitle_sources = obs_data_get_string(settings, "subtitle_sources");
if (subtitle_sources == nullptr || strcmp(subtitle_sources, "none") == 0 ||
strcmp(subtitle_sources, "(null)") == 0 || strlen(subtitle_sources) == 0) {
if (subtitle_sources == nullptr || strlen(subtitle_sources) == 0 ||
strcmp(subtitle_sources, "none") == 0 || strcmp(subtitle_sources, "(null)") == 0) {
obs_log(gf->log_level, "Create text source");
// check if a source called "LocalVocal Subtitles" exists
obs_source_t *source = obs_get_source_by_name("LocalVocal Subtitles");
if (source) {
// source exists, release it
obs_source_release(source);
} else {
// create a new OBS text source called "LocalVocal Subtitles"
create_obs_text_source();
}
create_obs_text_source_if_needed();
gf->text_source_name = "LocalVocal Subtitles";
obs_data_set_string(settings, "subtitle_sources", "LocalVocal Subtitles");
} else {
Expand All @@ -467,6 +498,12 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
gf->whisper_context = nullptr;

signal_handler_t *sh_filter = obs_source_get_signal_handler(gf->context);
if (sh_filter == nullptr) {
obs_log(LOG_ERROR, "Failed to get signal handler");
gf->active = false;
return nullptr;
}

signal_handler_connect(sh_filter, "enable", enable_callback, gf);

obs_log(gf->log_level, "run update");
Expand Down
17 changes: 15 additions & 2 deletions src/whisper-utils/whisper-model-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,29 @@ void update_whisper_model(struct transcription_filter_data *gf)
}

// Get settings from context
std::string new_model_path = obs_data_get_string(s, "whisper_model_path");
std::string new_model_path = obs_data_get_string(s, "whisper_model_path") != nullptr
? obs_data_get_string(s, "whisper_model_path")
: "";
std::string external_model_file_path =
obs_data_get_string(s, "whisper_model_path_external");
obs_data_get_string(s, "whisper_model_path_external") != nullptr
? obs_data_get_string(s, "whisper_model_path_external")
: "";
const bool new_dtw_timestamps = obs_data_get_bool(s, "dtw_token_timestamps");
obs_data_release(s);

// update the whisper model path

const bool is_external_model = new_model_path.find("!!!external!!!") != std::string::npos;

if (!is_external_model && new_model_path.empty()) {
obs_log(LOG_WARNING, "Whisper model path is empty");
return;
}
if (is_external_model && external_model_file_path.empty()) {
obs_log(LOG_WARNING, "External model file path is empty");
return;
}

char *silero_vad_model_file = obs_module_file("models/silero-vad/silero_vad.onnx");
if (silero_vad_model_file == nullptr) {
obs_log(LOG_ERROR, "Cannot find Silero VAD model file");
Expand Down

0 comments on commit 234a938

Please sign in to comment.