diff --git a/test/unit/user_dict/test_user_dict_model.py b/test/unit/user_dict/test_user_dict_model.py index 5c9b4753a..fd8b1f3c0 100644 --- a/test/unit/user_dict/test_user_dict_model.py +++ b/test/unit/user_dict/test_user_dict_model.py @@ -1,6 +1,6 @@ """UserDictWord のテスト""" -from typing import TypedDict +from typing import Literal, TypedDict, get_args import pytest from pydantic import ValidationError @@ -55,6 +55,46 @@ def test_valid_word() -> None: UserDictWord(**args) +CsvSafeStrFieldName = Literal[ + "part_of_speech", + "part_of_speech_detail_1", + "part_of_speech_detail_2", + "part_of_speech_detail_3", + "inflectional_type", + "inflectional_form", + "stem", + "yomi", + "accent_associative_rule", +] + + +@pytest.mark.parametrize( + "field", + get_args(CsvSafeStrFieldName), +) +def test_invalid_csv_safe_str(field: CsvSafeStrFieldName) -> None: + """UserDictWord の文字列 CSV で許可されない文字をエラーとする。""" + # Inputs + test_value_newlines = generate_model() + test_value_newlines[field] = "te\r\nst" + test_value_null = generate_model() + test_value_null[field] = "te\x00st" + test_value_comma = generate_model() + test_value_comma[field] = "te,st" + test_value_double_quote = generate_model() + test_value_double_quote[field] = 'te"st' + + # Test + with pytest.raises(ValidationError): + UserDictWord(**test_value_newlines) + with pytest.raises(ValidationError): + UserDictWord(**test_value_null) + with pytest.raises(ValidationError): + UserDictWord(**test_value_comma) + with pytest.raises(ValidationError): + UserDictWord(**test_value_double_quote) + + def test_convert_to_zenkaku() -> None: """UserDictWord は surface を全角にする。""" # Inputs @@ -126,6 +166,21 @@ def test_invalid_pronunciation_not_katakana() -> None: UserDictWord(**test_value) +def test_invalid_pronunciation_newlines_and_null() -> None: + """UserDictWord は pronunciation 内の改行や null 文字をエラーとする。""" + # Inputs + test_value_newlines = generate_model() + test_value_newlines["pronunciation"] = "ボイ\r\nボ" + test_value_null = generate_model() + test_value_null["pronunciation"] = "ボイ\x00ボ" + + # Test + with pytest.raises(ValidationError): + UserDictWord(**test_value_newlines) + with pytest.raises(ValidationError): + UserDictWord(**test_value_null) + + def test_invalid_pronunciation_invalid_sutegana() -> None: """UserDictWord は無効な pronunciation をエラーとする。""" # Inputs diff --git a/voicevox_engine/app/routers/user_dict.py b/voicevox_engine/app/routers/user_dict.py index 10b40c69d..6141b75a7 100644 --- a/voicevox_engine/app/routers/user_dict.py +++ b/voicevox_engine/app/routers/user_dict.py @@ -42,6 +42,7 @@ def get_user_dict_words() -> dict[str, UserDictWord]: status_code=500, detail="辞書の読み込みに失敗しました。" ) + # TODO: CsvSafeStrを使う @router.post("/user_dict_word", dependencies=[Depends(verify_mutability)]) def add_user_dict_word( surface: Annotated[str, Query(description="言葉の表層形")], diff --git a/voicevox_engine/user_dict/model.py b/voicevox_engine/user_dict/model.py index 0cc0dd88b..8ae1c1e2d 100644 --- a/voicevox_engine/user_dict/model.py +++ b/voicevox_engine/user_dict/model.py @@ -8,8 +8,9 @@ from re import findall, fullmatch from typing import Self -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import AfterValidator, BaseModel, ConfigDict, Field, model_validator from pydantic.json_schema import SkipJsonSchema +from typing_extensions import Annotated class WordTypes(str, Enum): @@ -26,6 +27,60 @@ class WordTypes(str, Enum): USER_DICT_MAX_PRIORITY = 10 +def _check_newlines_and_null(text: str) -> str: + if "\n" in text or "\r" in text: + raise ValueError("ユーザー辞書データ内に改行が含まれています。") + if "\x00" in text: + raise ValueError("ユーザー辞書データ内にnull文字が含まれています。") + return text + + +def _check_comma_and_double_quote(text: str) -> str: + if "," in text: + raise ValueError("ユーザー辞書データ内にカンマが含まれています。") + if '"' in text: + raise ValueError("ユーザー辞書データ内にダブルクォートが含まれています。") + return text + + +def _convert_to_zenkaku(surface: str) -> str: + return surface.translate( + str.maketrans( + "".join(chr(0x21 + i) for i in range(94)), + "".join(chr(0xFF01 + i) for i in range(94)), + ) + ) + + +def _check_is_katakana(pronunciation: str) -> str: + if not fullmatch(r"[ァ-ヴー]+", pronunciation): + raise ValueError("発音は有効なカタカナでなくてはいけません。") + sutegana = ["ァ", "ィ", "ゥ", "ェ", "ォ", "ャ", "ュ", "ョ", "ヮ", "ッ"] + for i in range(len(pronunciation)): + if pronunciation[i] in sutegana: + # 「キャット」のように、捨て仮名が連続する可能性が考えられるので、 + # 「ッ」に関しては「ッ」そのものが連続している場合と、「ッ」の後にほかの捨て仮名が連続する場合のみ無効とする + if i < len(pronunciation) - 1 and ( + pronunciation[i + 1] in sutegana[:-1] + or ( + pronunciation[i] == sutegana[-1] + and pronunciation[i + 1] == sutegana[-1] + ) + ): + raise ValueError("無効な発音です。(捨て仮名の連続)") + if pronunciation[i] == "ヮ": + if i != 0 and pronunciation[i - 1] not in ["ク", "グ"]: + raise ValueError("無効な発音です。(「くゎ」「ぐゎ」以外の「ゎ」の使用)") + return pronunciation + + +CsvSafeStr = Annotated[ + str, + AfterValidator(_check_newlines_and_null), + AfterValidator(_check_comma_and_double_quote), +] + + class UserDictWord(BaseModel): """ 辞書のコンパイルに使われる情報 @@ -33,58 +88,29 @@ class UserDictWord(BaseModel): model_config = ConfigDict(validate_assignment=True) - surface: str = Field(description="表層形") + surface: Annotated[ + str, + AfterValidator(_convert_to_zenkaku), + AfterValidator(_check_newlines_and_null), + ] = Field(description="表層形") priority: int = Field( description="優先度", ge=USER_DICT_MIN_PRIORITY, le=USER_DICT_MAX_PRIORITY ) context_id: int = Field(description="文脈ID", default=1348) - part_of_speech: str = Field(description="品詞") - part_of_speech_detail_1: str = Field(description="品詞細分類1") - part_of_speech_detail_2: str = Field(description="品詞細分類2") - part_of_speech_detail_3: str = Field(description="品詞細分類3") - inflectional_type: str = Field(description="活用型") - inflectional_form: str = Field(description="活用形") - stem: str = Field(description="原形") - yomi: str = Field(description="読み") - pronunciation: str = Field(description="発音") + part_of_speech: CsvSafeStr = Field(description="品詞") + part_of_speech_detail_1: CsvSafeStr = Field(description="品詞細分類1") + part_of_speech_detail_2: CsvSafeStr = Field(description="品詞細分類2") + part_of_speech_detail_3: CsvSafeStr = Field(description="品詞細分類3") + inflectional_type: CsvSafeStr = Field(description="活用型") + inflectional_form: CsvSafeStr = Field(description="活用形") + stem: CsvSafeStr = Field(description="原形") + yomi: CsvSafeStr = Field(description="読み") + pronunciation: Annotated[CsvSafeStr, AfterValidator(_check_is_katakana)] = Field( + description="発音" + ) accent_type: int = Field(description="アクセント型") mora_count: int | SkipJsonSchema[None] = Field(default=None, description="モーラ数") - accent_associative_rule: str = Field(description="アクセント結合規則") - - @field_validator("surface") - @classmethod - def convert_to_zenkaku(cls, surface: str) -> str: - return surface.translate( - str.maketrans( - "".join(chr(0x21 + i) for i in range(94)), - "".join(chr(0xFF01 + i) for i in range(94)), - ) - ) - - @field_validator("pronunciation", mode="before") - @classmethod - def check_is_katakana(cls, pronunciation: str) -> str: - if not fullmatch(r"[ァ-ヴー]+", pronunciation): - raise ValueError("発音は有効なカタカナでなくてはいけません。") - sutegana = ["ァ", "ィ", "ゥ", "ェ", "ォ", "ャ", "ュ", "ョ", "ヮ", "ッ"] - for i in range(len(pronunciation)): - if pronunciation[i] in sutegana: - # 「キャット」のように、捨て仮名が連続する可能性が考えられるので、 - # 「ッ」に関しては「ッ」そのものが連続している場合と、「ッ」の後にほかの捨て仮名が連続する場合のみ無効とする - if i < len(pronunciation) - 1 and ( - pronunciation[i + 1] in sutegana[:-1] - or ( - pronunciation[i] == sutegana[-1] - and pronunciation[i + 1] == sutegana[-1] - ) - ): - raise ValueError("無効な発音です。(捨て仮名の連続)") - if pronunciation[i] == "ヮ": - if i != 0 and pronunciation[i - 1] not in ["ク", "グ"]: - raise ValueError( - "無効な発音です。(「くゎ」「ぐゎ」以外の「ゎ」の使用)" - ) - return pronunciation + accent_associative_rule: CsvSafeStr = Field(description="アクセント結合規則") @model_validator(mode="after") def check_mora_count_and_accent_type(self) -> Self: