From 05fb7485ce4f3096a37189bab4119dbdef2c16cc Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Fri, 2 Aug 2024 21:00:58 -0400 Subject: [PATCH 1/2] refactor: Add target SPM loading and decoding logic in translation module --- src/translation/translation.cpp | 25 +++++++++++++++++++++++-- src/translation/translation.h | 1 + 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/translation/translation.cpp b/src/translation/translation.cpp index 205a09d..c1ef7b7 100644 --- a/src/translation/translation.cpp +++ b/src/translation/translation.cpp @@ -31,7 +31,9 @@ int build_translation_context(struct translation_context &translation_ctx) obs_log(LOG_INFO, "Building translation context from '%s'...", local_model_path.c_str()); // find the SPM file in the model folder std::string local_spm_path = find_file_in_folder_by_regex_expression( - local_model_path, "(sentencepiece|spm|spiece).*?\\.model"); + local_model_path, "(sentencepiece|spm|spiece|source).*?\\.(model|spm)"); + std::string target_spm_path = + find_file_in_folder_by_regex_expression(local_model_path, "target.*?\\.spm"); try { obs_log(LOG_INFO, "Loading SPM from %s", local_spm_path.c_str()); @@ -42,6 +44,21 @@ int build_translation_context(struct translation_context &translation_ctx) return OBS_POLYGLOT_TRANSLATION_INIT_FAIL; } + if (!target_spm_path.empty()) { + obs_log(LOG_INFO, "Loading target SPM from %s", target_spm_path.c_str()); + translation_ctx.target_processor.reset( + new sentencepiece::SentencePieceProcessor()); + const auto status = translation_ctx.target_processor->Load(target_spm_path); + if (!status.ok()) { + obs_log(LOG_ERROR, "Failed to load target SPM: %s", + status.ToString().c_str()); + return OBS_POLYGLOT_TRANSLATION_INIT_FAIL; + } + } else { + obs_log(LOG_INFO, "Target SPM not found, using source SPM for target"); + translation_ctx.target_processor.release(); + } + translation_ctx.tokenizer = [&translation_ctx](const std::string &text) { std::vector tokens; translation_ctx.processor->Encode(text, &tokens); @@ -50,7 +67,11 @@ int build_translation_context(struct translation_context &translation_ctx) translation_ctx.detokenizer = [&translation_ctx](const std::vector &tokens) { std::string text; - translation_ctx.processor->Decode(tokens, &text); + if (translation_ctx.target_processor) { + translation_ctx.target_processor->Decode(tokens, &text); + } else { + translation_ctx.processor->Decode(tokens, &text); + } return std::regex_replace(text, std::regex(""), "UNK"); }; diff --git a/src/translation/translation.h b/src/translation/translation.h index 1b601fc..0d45080 100644 --- a/src/translation/translation.h +++ b/src/translation/translation.h @@ -20,6 +20,7 @@ class SentencePieceProcessor; struct translation_context { std::string local_model_folder_path; std::unique_ptr processor; + std::unique_ptr target_processor; std::unique_ptr translator; std::unique_ptr options; std::function(const std::string &)> tokenizer; From fadcac4b5cfa3d0439e228320b08297bc6ca1263 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Fri, 2 Aug 2024 22:20:54 -0400 Subject: [PATCH 2/2] refactor: Update target SPM loading error handling in translation module --- src/translation/translation.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/translation/translation.cpp b/src/translation/translation.cpp index c1ef7b7..e11f072 100644 --- a/src/translation/translation.cpp +++ b/src/translation/translation.cpp @@ -48,10 +48,11 @@ int build_translation_context(struct translation_context &translation_ctx) obs_log(LOG_INFO, "Loading target SPM from %s", target_spm_path.c_str()); translation_ctx.target_processor.reset( new sentencepiece::SentencePieceProcessor()); - const auto status = translation_ctx.target_processor->Load(target_spm_path); - if (!status.ok()) { + const auto target_status = + translation_ctx.target_processor->Load(target_spm_path); + if (!target_status.ok()) { obs_log(LOG_ERROR, "Failed to load target SPM: %s", - status.ToString().c_str()); + target_status.ToString().c_str()); return OBS_POLYGLOT_TRANSLATION_INIT_FAIL; } } else {