Load an optional target-language SentencePiece model ("target*.spm") next to
the source one and use it for detokenization; when none is found, decoding
falls back to the source processor.

diff --git a/src/translation/translation.cpp b/src/translation/translation.cpp
index 205a09d..e11f072 100644
--- a/src/translation/translation.cpp
+++ b/src/translation/translation.cpp
@@ -31,7 +31,9 @@ int build_translation_context(struct translation_context &translation_ctx)
 	obs_log(LOG_INFO, "Building translation context from '%s'...", local_model_path.c_str());
 	// find the SPM file in the model folder
 	std::string local_spm_path = find_file_in_folder_by_regex_expression(
-		local_model_path, "(sentencepiece|spm|spiece).*?\\.model");
+		local_model_path, "(sentencepiece|spm|spiece|source).*?\\.(model|spm)");
+	std::string target_spm_path =
+		find_file_in_folder_by_regex_expression(local_model_path, "target.*?\\.spm");
 
 	try {
 		obs_log(LOG_INFO, "Loading SPM from %s", local_spm_path.c_str());
@@ -42,6 +44,25 @@ int build_translation_context(struct translation_context &translation_ctx)
 			return OBS_POLYGLOT_TRANSLATION_INIT_FAIL;
 		}
 
+		if (!target_spm_path.empty()) {
+			obs_log(LOG_INFO, "Loading target SPM from %s", target_spm_path.c_str());
+			translation_ctx.target_processor.reset(
+				new sentencepiece::SentencePieceProcessor());
+			const auto target_status =
+				translation_ctx.target_processor->Load(target_spm_path);
+			if (!target_status.ok()) {
+				obs_log(LOG_ERROR, "Failed to load target SPM: %s",
+					target_status.ToString().c_str());
+				return OBS_POLYGLOT_TRANSLATION_INIT_FAIL;
+			}
+		} else {
+			obs_log(LOG_INFO, "Target SPM not found, using source SPM for target");
+			// reset() destroys any processor already held; release() would
+			// only give up ownership and leak it. A null target_processor
+			// makes the detokenizer below fall back to the source processor.
+			translation_ctx.target_processor.reset();
+		}
+
 		translation_ctx.tokenizer = [&translation_ctx](const std::string &text) {
 			std::vector tokens;
 			translation_ctx.processor->Encode(text, &tokens);
@@ -50,7 +71,11 @@ int build_translation_context(struct translation_context &translation_ctx)
 		translation_ctx.detokenizer =
 			[&translation_ctx](const std::vector &tokens) {
 				std::string text;
-				translation_ctx.processor->Decode(tokens, &text);
+				if (translation_ctx.target_processor) {
+					translation_ctx.target_processor->Decode(tokens, &text);
+				} else {
+					translation_ctx.processor->Decode(tokens, &text);
+				}
 				return std::regex_replace(text, std::regex(""), "UNK");
 			};
 
diff --git a/src/translation/translation.h b/src/translation/translation.h
index 1b601fc..0d45080 100644
--- a/src/translation/translation.h
+++ b/src/translation/translation.h
@@ -20,6 +20,8 @@ class SentencePieceProcessor;
 struct translation_context {
 	std::string local_model_folder_path;
 	std::unique_ptr processor;
+	// SPM used for decoding; null when no target SPM file was found (processor is used)
+	std::unique_ptr<sentencepiece::SentencePieceProcessor> target_processor;
 	std::unique_ptr translator;
 	std::unique_ptr options;
 	std::function(const std::string &)> tokenizer;