Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: ユーザー辞書データに改行やnull文字が入っていた場合にエラーとする #1522

Merged
merged 11 commits into from
Feb 7, 2025
57 changes: 56 additions & 1 deletion test/unit/user_dict/test_user_dict_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""UserDictWord のテスト"""

from typing import TypedDict
from typing import Literal, TypedDict, get_args

import pytest
from pydantic import ValidationError
Expand Down Expand Up @@ -55,6 +55,46 @@ def test_valid_word() -> None:
UserDictWord(**args)


# Names of the UserDictWord string fields that are validated as CSV-safe
# (i.e. typed as CsvSafeStr in the model: no newlines, NUL, comma, or
# double quote). Used to parameterize the rejection test below.
CsvSafeStrFieldName = Literal[
    "part_of_speech",
    "part_of_speech_detail_1",
    "part_of_speech_detail_2",
    "part_of_speech_detail_3",
    "inflectional_type",
    "inflectional_form",
    "stem",
    "yomi",
    "accent_associative_rule",
]


@pytest.mark.parametrize("field", get_args(CsvSafeStrFieldName))
def test_invalid_csv_safe_str(field: CsvSafeStrFieldName) -> None:
    """UserDictWord rejects characters that are unsafe inside a CSV string field."""
    # One sample per forbidden character class: CRLF, NUL, comma, double quote.
    forbidden_samples = ["te\r\nst", "te\x00st", "te,st", 'te"st']

    for sample in forbidden_samples:
        args = generate_model()
        args[field] = sample
        with pytest.raises(ValidationError):
            UserDictWord(**args)


def test_convert_to_zenkaku() -> None:
"""UserDictWord は surface を全角にする。"""
# Inputs
Expand Down Expand Up @@ -126,6 +166,21 @@ def test_invalid_pronunciation_not_katakana() -> None:
UserDictWord(**test_value)


def test_invalid_pronunciation_newlines_and_null() -> None:
    """UserDictWord rejects newline and NUL characters in pronunciation."""
    for bad_pronunciation in ("ボイ\r\nボ", "ボイ\x00ボ"):
        args = generate_model()
        args["pronunciation"] = bad_pronunciation
        with pytest.raises(ValidationError):
            UserDictWord(**args)


def test_invalid_pronunciation_invalid_sutegana() -> None:
"""UserDictWord は無効な pronunciation をエラーとする。"""
# Inputs
Expand Down
1 change: 1 addition & 0 deletions voicevox_engine/app/routers/user_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def get_user_dict_words() -> dict[str, UserDictWord]:
status_code=500, detail="辞書の読み込みに失敗しました。"
)

# TODO: CsvSafeStrを使う
@router.post("/user_dict_word", dependencies=[Depends(verify_mutability)])
def add_user_dict_word(
surface: Annotated[str, Query(description="言葉の表層形")],
Expand Down
120 changes: 73 additions & 47 deletions voicevox_engine/user_dict/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from re import findall, fullmatch
takana-v marked this conversation as resolved.
Show resolved Hide resolved
from typing import Self

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic import AfterValidator, BaseModel, ConfigDict, Field, model_validator
from pydantic.json_schema import SkipJsonSchema
from typing_extensions import Annotated


class WordTypes(str, Enum):
Expand All @@ -26,65 +27,90 @@ class WordTypes(str, Enum):
USER_DICT_MAX_PRIORITY = 10


def _check_newlines_and_null(text: str) -> str:
if "\n" in text or "\r" in text:
raise ValueError("ユーザー辞書データ内に改行が含まれています。")
if "\x00" in text:
raise ValueError("ユーザー辞書データ内にnull文字が含まれています。")
return text


def _check_comma_and_double_quote(text: str) -> str:
if "," in text:
raise ValueError("ユーザー辞書データ内にカンマが含まれています。")
if '"' in text:
raise ValueError("ユーザー辞書データ内にダブルクォートが含まれています。")
return text


def _convert_to_zenkaku(surface: str) -> str:
return surface.translate(
str.maketrans(
"".join(chr(0x21 + i) for i in range(94)),
"".join(chr(0xFF01 + i) for i in range(94)),
)
)


def _check_is_katakana(pronunciation: str) -> str:
if not fullmatch(r"[ァ-ヴー]+", pronunciation):
raise ValueError("発音は有効なカタカナでなくてはいけません。")
sutegana = ["ァ", "ィ", "ゥ", "ェ", "ォ", "ャ", "ュ", "ョ", "ヮ", "ッ"]
for i in range(len(pronunciation)):
if pronunciation[i] in sutegana:
# 「キャット」のように、捨て仮名が連続する可能性が考えられるので、
# 「ッ」に関しては「ッ」そのものが連続している場合と、「ッ」の後にほかの捨て仮名が連続する場合のみ無効とする
if i < len(pronunciation) - 1 and (
pronunciation[i + 1] in sutegana[:-1]
or (
pronunciation[i] == sutegana[-1]
and pronunciation[i + 1] == sutegana[-1]
)
):
raise ValueError("無効な発音です。(捨て仮名の連続)")
if pronunciation[i] == "ヮ":
if i != 0 and pronunciation[i - 1] not in ["ク", "グ"]:
raise ValueError("無効な発音です。(「くゎ」「ぐゎ」以外の「ゎ」の使用)")
return pronunciation


# String type for user-dictionary fields that end up embedded in CSV rows.
# Pydantic runs both AfterValidators on assignment/construction, rejecting
# newlines, NUL characters, commas, and double quotes.
CsvSafeStr = Annotated[
    str,
    AfterValidator(_check_newlines_and_null),
    AfterValidator(_check_comma_and_double_quote),
]


class UserDictWord(BaseModel):
"""
辞書のコンパイルに使われる情報
"""

model_config = ConfigDict(validate_assignment=True)

surface: str = Field(description="表層形")
surface: Annotated[
str,
AfterValidator(_convert_to_zenkaku),
AfterValidator(_check_newlines_and_null),
] = Field(description="表層形")
priority: int = Field(
description="優先度", ge=USER_DICT_MIN_PRIORITY, le=USER_DICT_MAX_PRIORITY
)
context_id: int = Field(description="文脈ID", default=1348)
part_of_speech: str = Field(description="品詞")
part_of_speech_detail_1: str = Field(description="品詞細分類1")
part_of_speech_detail_2: str = Field(description="品詞細分類2")
part_of_speech_detail_3: str = Field(description="品詞細分類3")
inflectional_type: str = Field(description="活用型")
inflectional_form: str = Field(description="活用形")
stem: str = Field(description="原形")
yomi: str = Field(description="読み")
pronunciation: str = Field(description="発音")
part_of_speech: CsvSafeStr = Field(description="品詞")
part_of_speech_detail_1: CsvSafeStr = Field(description="品詞細分類1")
part_of_speech_detail_2: CsvSafeStr = Field(description="品詞細分類2")
part_of_speech_detail_3: CsvSafeStr = Field(description="品詞細分類3")
inflectional_type: CsvSafeStr = Field(description="活用型")
inflectional_form: CsvSafeStr = Field(description="活用形")
stem: CsvSafeStr = Field(description="原形")
yomi: CsvSafeStr = Field(description="読み")
pronunciation: Annotated[CsvSafeStr, AfterValidator(_check_is_katakana)] = Field(
description="発音"
)
accent_type: int = Field(description="アクセント型")
mora_count: int | SkipJsonSchema[None] = Field(default=None, description="モーラ数")
accent_associative_rule: str = Field(description="アクセント結合規則")

@field_validator("surface")
@classmethod
def convert_to_zenkaku(cls, surface: str) -> str:
return surface.translate(
str.maketrans(
"".join(chr(0x21 + i) for i in range(94)),
"".join(chr(0xFF01 + i) for i in range(94)),
)
)

@field_validator("pronunciation", mode="before")
@classmethod
def check_is_katakana(cls, pronunciation: str) -> str:
if not fullmatch(r"[ァ-ヴー]+", pronunciation):
raise ValueError("発音は有効なカタカナでなくてはいけません。")
sutegana = ["ァ", "ィ", "ゥ", "ェ", "ォ", "ャ", "ュ", "ョ", "ヮ", "ッ"]
for i in range(len(pronunciation)):
if pronunciation[i] in sutegana:
# 「キャット」のように、捨て仮名が連続する可能性が考えられるので、
# 「ッ」に関しては「ッ」そのものが連続している場合と、「ッ」の後にほかの捨て仮名が連続する場合のみ無効とする
if i < len(pronunciation) - 1 and (
pronunciation[i + 1] in sutegana[:-1]
or (
pronunciation[i] == sutegana[-1]
and pronunciation[i + 1] == sutegana[-1]
)
):
raise ValueError("無効な発音です。(捨て仮名の連続)")
if pronunciation[i] == "ヮ":
if i != 0 and pronunciation[i - 1] not in ["ク", "グ"]:
raise ValueError(
"無効な発音です。(「くゎ」「ぐゎ」以外の「ゎ」の使用)"
)
return pronunciation
accent_associative_rule: CsvSafeStr = Field(description="アクセント結合規則")

@model_validator(mode="after")
def check_mora_count_and_accent_type(self) -> Self:
Expand Down