Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Update language codes in translation module #140

Merged
1 commit merged on Jul 18, 2024
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
9 changes: 4 additions & 5 deletions src/tests/localvocal-offline-test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "transcription-utils.h"
#include "whisper-utils/whisper-utils.h"
#include "audio-file-utils.h"
#include "translation/language_codes.h"

#include <stdio.h>
#include <stdlib.h>
Expand Down Expand Up @@ -145,7 +146,6 @@ create_context(int sample_rate, int channels, const std::string &whisper_model_p
gf->last_sub_render_time = 0;
gf->buffered_output = false;

gf->source_lang = "";
gf->target_lang = "";
gf->translation_ctx.add_context = true;
gf->translation_output = "";
Expand Down Expand Up @@ -266,10 +266,10 @@ void set_text_callback(struct transcription_filter_data *gf,
}

if (gf->translate) {
obs_log(gf->log_level, "Translating text. %s -> %s",
gf->source_lang.c_str(), gf->target_lang.c_str());
obs_log(gf->log_level, "Translating text to %s", gf->target_lang.c_str());
std::string translated_text;
if (translate(gf->translation_ctx, str_copy, gf->source_lang,
if (translate(gf->translation_ctx, str_copy,
language_codes_from_whisper[gf->whisper_params.language],
gf->target_lang,
translated_text) == OBS_POLYGLOT_TRANSLATION_SUCCESS) {
if (gf->log_words) {
Expand Down Expand Up @@ -365,7 +365,6 @@ int wmain(int argc, wchar_t *argv[])
"Source or target translation language are empty or disabled");
} else {
obs_log(LOG_INFO, "Setting translation languages");
gf->source_lang = sourceLanguageStr;
gf->target_lang = targetLanguageStr;
build_and_enable_translation(gf, ct2ModelFolderStr.c_str());
}
Expand Down
12 changes: 8 additions & 4 deletions src/transcription-filter-callbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "translation/translation-includes.h"
#include "whisper-utils/whisper-utils.h"
#include "whisper-utils/whisper-model-utils.h"
#include "translation/language_codes.h"

void send_caption_to_source(const std::string &target_source_name, const std::string &caption,
struct transcription_filter_data *gf)
Expand Down Expand Up @@ -49,15 +50,17 @@ void audio_chunk_callback(struct transcription_filter_data *gf, const float *pcm
}

std::string send_sentence_to_translation(const std::string &sentence,
struct transcription_filter_data *gf)
struct transcription_filter_data *gf,
const std::string &source_language)
{
const std::string last_text = gf->last_text;
gf->last_text = sentence;
if (gf->translate && !sentence.empty() && sentence != last_text) {
obs_log(gf->log_level, "Translating text. %s -> %s", gf->source_lang.c_str(),
obs_log(gf->log_level, "Translating text. %s -> %s", source_language.c_str(),
gf->target_lang.c_str());
std::string translated_text;
if (translate(gf->translation_ctx, sentence, gf->source_lang, gf->target_lang,
if (translate(gf->translation_ctx, sentence,
language_codes_from_whisper[source_language], gf->target_lang,
translated_text) == OBS_POLYGLOT_TRANSLATION_SUCCESS) {
if (gf->log_words) {
obs_log(LOG_INFO, "Translation: '%s' -> '%s'", sentence.c_str(),
Expand Down Expand Up @@ -219,7 +222,8 @@ void set_text_callback(struct transcription_filter_data *gf,
}

// send the sentence to translation (if enabled)
std::string translated_sentence = send_sentence_to_translation(str_copy, gf);
std::string translated_sentence =
send_sentence_to_translation(str_copy, gf, result.language);

if (gf->translate) {
if (gf->translation_output == "none") {
Expand Down
1 change: 0 additions & 1 deletion src/transcription-filter-data.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ struct transcription_filter_data {
bool process_while_muted = false;
bool rename_file_to_match_recording = false;
bool translate = false;
std::string source_lang;
std::string target_lang;
std::string translation_output;
bool enable_token_ts_dtw = false;
Expand Down
18 changes: 4 additions & 14 deletions src/transcription-filter-properties.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,10 @@ bool translation_options_callback(obs_properties_t *props, obs_property_t *prope
obs_property_set_visible(obs_properties_get(props, prop), translate_enabled);
}
for (const auto &prop :
{"translate_source_language", "translate_add_context",
"translate_input_tokenization_style", "translation_sampling_temperature",
"translation_repetition_penalty", "translation_beam_size",
"translation_max_decoding_length", "translation_no_repeat_ngram_size",
"translation_max_input_length"}) {
{"translate_add_context", "translate_input_tokenization_style",
"translation_sampling_temperature", "translation_repetition_penalty",
"translation_beam_size", "translation_max_decoding_length",
"translation_no_repeat_ngram_size", "translation_max_input_length"}) {
obs_property_set_visible(obs_properties_get(props, prop),
translate_enabled && is_advanced);
}
Expand Down Expand Up @@ -130,8 +129,6 @@ bool translation_external_model_selection(obs_properties_t *props, obs_property_
const bool is_advanced = obs_data_get_int(settings, "advanced_settings_mode") == 1;
obs_property_set_visible(obs_properties_get(props, "translation_model_path_external"),
is_external);
obs_property_set_visible(obs_properties_get(props, "translate_source_language"),
!is_whisper && is_advanced);
obs_property_set_visible(obs_properties_get(props, "translate_add_context"),
!is_whisper && is_advanced);
obs_property_set_visible(obs_properties_get(props, "translate_input_tokenization_style"),
Expand Down Expand Up @@ -214,27 +211,20 @@ void add_translation_group_properties(obs_properties_t *ppts)
obs_property_t *prop_tgt = obs_properties_add_list(
translation_group, "translate_target_language", MT_("target_language"),
OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_STRING);
obs_property_t *prop_src = obs_properties_add_list(
translation_group, "translate_source_language", MT_("source_language"),
OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_STRING);
obs_properties_add_bool(translation_group, "translate_add_context",
MT_("translate_add_context"));

// Populate the dropdown with the language codes
for (const auto &language : language_codes) {
obs_property_list_add_string(prop_tgt, language.second.c_str(),
language.first.c_str());
obs_property_list_add_string(prop_src, language.second.c_str(),
language.first.c_str());
}
// add option for routing the translation to an output source
obs_property_t *prop_output = obs_properties_add_list(translation_group, "translate_output",
MT_("translate_output"),
OBS_COMBO_TYPE_LIST,
OBS_COMBO_FORMAT_STRING);
obs_property_list_add_string(prop_output, "Write to captions output", "none");
// TODO add file output option
// obs_property_list_add_string(...
obs_enum_sources(add_sources_to_list, prop_output);

// add callback to enable/disable translation group
Expand Down
6 changes: 2 additions & 4 deletions src/transcription-filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,6 @@ void transcription_filter_update(void *data, obs_data_t *s)
}

bool new_translate = obs_data_get_bool(s, "translate");
gf->source_lang = obs_data_get_string(s, "translate_source_language");
gf->target_lang = obs_data_get_string(s, "translate_target_language");
gf->translation_ctx.add_context = obs_data_get_bool(s, "translate_add_context");
gf->translation_ctx.input_tokenization_style =
Expand Down Expand Up @@ -358,9 +357,9 @@ void transcription_filter_update(void *data, obs_data_t *s)
: "auto";
} else {
// take the language from gf->target_lang
if (language_codes_2_reverse.count(gf->target_lang) > 0) {
if (language_codes_to_whisper.count(gf->target_lang) > 0) {
gf->whisper_params.language =
language_codes_2_reverse[gf->target_lang].c_str();
language_codes_to_whisper[gf->target_lang].c_str();
} else {
gf->whisper_params.language = "auto";
}
Expand Down Expand Up @@ -580,7 +579,6 @@ void transcription_filter_defaults(obs_data_t *s)
obs_data_set_default_bool(s, "advanced_settings", false);
obs_data_set_default_bool(s, "translate", false);
obs_data_set_default_string(s, "translate_target_language", "__es__");
obs_data_set_default_string(s, "translate_source_language", "__en__");
obs_data_set_default_bool(s, "translate_add_context", true);
obs_data_set_default_string(s, "translate_model", "whisper-based-translation");
obs_data_set_default_string(s, "translation_model_path_external", "");
Expand Down
4 changes: 2 additions & 2 deletions src/translation/language_codes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ std::map<std::string, std::string> language_codes_reverse = {{"Afrikaans", "__af
{"Chinese", "__zh__"},
{"Zulu", "__zu__"}};

std::map<std::string, std::string> language_codes_2 = {
std::map<std::string, std::string> language_codes_from_whisper = {
{"af", "__af__"}, {"am", "__am__"}, {"ar", "__ar__"}, {"ast", "__ast__"},
{"az", "__az__"}, {"ba", "__ba__"}, {"be", "__be__"}, {"bg", "__bg__"},
{"bn", "__bn__"}, {"br", "__br__"}, {"bs", "__bs__"}, {"ca", "__ca__"},
Expand All @@ -228,7 +228,7 @@ std::map<std::string, std::string> language_codes_2 = {
{"uz", "__uz__"}, {"vi", "__vi__"}, {"wo", "__wo__"}, {"xh", "__xh__"},
{"yi", "__yi__"}, {"yo", "__yo__"}, {"zh", "__zh__"}, {"zu", "__zu__"}};

std::map<std::string, std::string> language_codes_2_reverse = {
std::map<std::string, std::string> language_codes_to_whisper = {
{"__af__", "af"}, {"__am__", "am"}, {"__ar__", "ar"}, {"__ast__", "ast"},
{"__az__", "az"}, {"__ba__", "ba"}, {"__be__", "be"}, {"__bg__", "bg"},
{"__bn__", "bn"}, {"__br__", "br"}, {"__bs__", "bs"}, {"__ca__", "ca"},
Expand Down
4 changes: 2 additions & 2 deletions src/translation/language_codes.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

extern std::map<std::string, std::string> language_codes;
extern std::map<std::string, std::string> language_codes_reverse;
extern std::map<std::string, std::string> language_codes_2;
extern std::map<std::string, std::string> language_codes_2_reverse;
extern std::map<std::string, std::string> language_codes_from_whisper;
extern std::map<std::string, std::string> language_codes_to_whisper;

#endif // LANGUAGE_CODES_H
2 changes: 1 addition & 1 deletion src/translation/translation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ int translate(struct translation_context &translation_ctx, const std::string &te
// set input tokens
std::vector<std::string> input_tokens = {};
std::vector<std::string> new_input_tokens = translation_ctx.tokenizer(
"<2" + language_codes_2_reverse[target_lang] + "> " + text);
"<2" + language_codes_to_whisper[target_lang] + "> " + text);
input_tokens.insert(input_tokens.end(), new_input_tokens.begin(),
new_input_tokens.end());
const std::vector<std::vector<std::string>> batch = {input_tokens};
Expand Down
2 changes: 1 addition & 1 deletion src/whisper-utils/token-buffer-thread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ void TokenBufferThread::monitor()
contribution.end());
#endif

obs_log(LOG_INFO, "TokenBufferThread::monitor: output '%s'",
obs_log(gf->log_level, "TokenBufferThread::monitor: output '%s'",
contribution_out.c_str());
this->sentenceOutputCallback(contribution_out);
lastContributionIsSent = true;
Expand Down
33 changes: 23 additions & 10 deletions src/whisper-utils/whisper-processing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,12 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
{
if (gf == nullptr) {
obs_log(LOG_ERROR, "run_whisper_inference: gf is null");
return {DETECTION_RESULT_UNKNOWN, "", start_offset_ms, end_offset_ms, {}};
return {DETECTION_RESULT_UNKNOWN, "", start_offset_ms, end_offset_ms, {}, ""};
}

if (pcm32f_data_ == nullptr || pcm32f_num_samples == 0) {
obs_log(LOG_ERROR, "run_whisper_inference: pcm32f_data is null or size is 0");
return {DETECTION_RESULT_UNKNOWN, "", start_offset_ms, end_offset_ms, {}};
return {DETECTION_RESULT_UNKNOWN, "", start_offset_ms, end_offset_ms, {}, ""};
}

obs_log(gf->log_level, "%s: processing %d samples, %.3f sec, %d threads", __func__,
Expand Down Expand Up @@ -169,7 +169,7 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
if (gf->whisper_context == nullptr) {
obs_log(LOG_WARNING, "whisper context is null");
return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}};
return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""};
}

// run the inference
Expand All @@ -185,15 +185,23 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
if (should_free_buffer) {
bfree(pcm32f_data);
}
return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}};
return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""};
}
if (should_free_buffer) {
bfree(pcm32f_data);
}

std::string language = gf->whisper_params.language;
if (gf->whisper_params.language == nullptr || strlen(gf->whisper_params.language) == 0 ||
strcmp(gf->whisper_params.language, "auto") == 0) {
int lang_id = whisper_lang_auto_detect(gf->whisper_context, 0, 1, nullptr);
language = whisper_lang_str(lang_id);
obs_log(gf->log_level, "Detected language: %s", language.c_str());
}

if (whisper_full_result != 0) {
obs_log(LOG_WARNING, "failed to process audio, error %d", whisper_full_result);
return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}};
return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""};
} else {
float sentence_p = 0.0f;
std::string text = "";
Expand Down Expand Up @@ -235,7 +243,12 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
// ratio is too high, skip this detection
obs_log(gf->log_level,
"Time token ratio too high, skipping");
return {DETECTION_RESULT_SILENCE, "", t0, t1, {}};
return {DETECTION_RESULT_SILENCE,
"",
t0,
t1,
{},
language};
}
keep = false;
}
Expand All @@ -253,7 +266,7 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
if (sentence_p < gf->sentence_psum_accept_thresh) {
obs_log(gf->log_level, "Sentence psum %.3f below threshold %.3f, skipping",
sentence_p, gf->sentence_psum_accept_thresh);
return {DETECTION_RESULT_SILENCE, "", t0, t1, {}};
return {DETECTION_RESULT_SILENCE, "", t0, t1, {}, language};
}

obs_log(gf->log_level, "Decoded sentence: '%s'", text.c_str());
Expand All @@ -264,10 +277,10 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
}

if (text.empty() || text == "." || text == " " || text == "\n") {
return {DETECTION_RESULT_SILENCE, "", t0, t1, {}};
return {DETECTION_RESULT_SILENCE, "", t0, t1, {}, language};
}

return {DETECTION_RESULT_SPEECH, text, t0, t1, tokens};
return {DETECTION_RESULT_SPEECH, text, t0, t1, tokens, language};
}
}

Expand Down Expand Up @@ -509,7 +522,7 @@ void whisper_loop(void *data)
uint64_t now = now_ms();
if ((now - gf->last_sub_render_time) > gf->min_sub_duration) {
// clear the current sub, call the callback with an empty string
obs_log(LOG_INFO,
obs_log(gf->log_level,
"Clearing current subtitle. now: %lu ms, last: %lu ms", now,
gf->last_sub_render_time);
set_text_callback(gf, {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}});
Expand Down
1 change: 1 addition & 0 deletions src/whisper-utils/whisper-processing.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ struct DetectionResultWithText {
uint64_t start_timestamp_ms;
uint64_t end_timestamp_ms;
std::vector<whisper_token_data> tokens;
std::string language;
};

enum VadState { VAD_STATE_WAS_ON = 0, VAD_STATE_WAS_OFF, VAD_STATE_IS_OFF };
Expand Down
Loading