Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: SupportedDevicesからデシアライズ機能を剥奪 #958

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions crates/voicevox_core/src/devices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{
};

use derive_more::BitAnd;
use serde::{Deserialize, Serialize};
use serde::Serialize;

pub(crate) fn test_gpus(
gpus: impl IntoIterator<Item = GpuSpec>,
Expand Down Expand Up @@ -65,7 +65,8 @@ fn test_gpu(
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Copy, PartialEq, Eq, Debug, BitAnd, Serialize, Deserialize)]
// 将来の互換性保証のため、`Deserialize`は実装するべきではない
#[derive(Clone, Copy, PartialEq, Eq, Debug, BitAnd, Serialize)]
#[non_exhaustive]
pub struct SupportedDevices {
/// CPUが利用可能。
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
// エンジンを起動してyukarin_s・yukarin_sa・decodeの推論を行う

use std::collections::HashMap;
use std::sync::LazyLock;
use std::{cmp::min, ffi::CStr};

use assert_cmd::assert::AssertResult;
use libloading::Library;
use serde::{Deserialize, Serialize};
use voicevox_core::SupportedDevices;

use test_util::{c_api::CApi, EXAMPLE_DATA};

Expand All @@ -33,7 +33,9 @@ impl assert_cdylib::TestCase for TestCase {

{
let supported_devices = lib.supported_devices();
serde_json::from_str::<SupportedDevices>(CStr::from_ptr(supported_devices).to_str()?)?;
serde_json::from_str::<HashMap<String, bool>>(
CStr::from_ptr(supported_devices).to_str()?,
)?;
}

assert!(lib.initialize(false, 0, false));
Expand Down
5 changes: 3 additions & 2 deletions crates/voicevox_core_c_api/tests/e2e/testcases/global_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use libloading::Library;
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};
use test_util::c_api::{self, CApi, VoicevoxLoadOnnxruntimeOptions, VoicevoxResultCode};
use voicevox_core::SupportedDevices;

use crate::{
assert_cdylib::{self, case, Utf8Output},
Expand Down Expand Up @@ -65,7 +64,9 @@ impl assert_cdylib::TestCase for TestCase {
supported_devices.as_mut_ptr(),
));
let supported_devices = supported_devices.assume_init();
serde_json::from_str::<SupportedDevices>(CStr::from_ptr(supported_devices).to_str()?)?;
serde_json::from_str::<HashMap<String, bool>>(
CStr::from_ptr(supported_devices).to_str()?,
)?;
lib.voicevox_json_free(supported_devices);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ public static String getVersion() {
*
* <p>あくまでONNX Runtimeが対応しているデバイスの情報であることに注意。GPUが使える環境ではなかったとしても {@link #cuda} や {@link #dml} は
* {@code true} を示しうる。
*
* <p>{@code Gson#fromJson} でJSONから変換することはできない。その試みは {@link UnsupportedOperationException} となる。
*/
public static class SupportedDevices {
/**
Expand Down Expand Up @@ -71,9 +73,14 @@ public static class SupportedDevices {
public final boolean dml;

private SupportedDevices() {
this.cpu = false;
this.cuda = false;
this.dml = false;
throw new UnsupportedOperationException("You cannot deserialize `SupportedDevices`");
}

/** accessed only via JNI */
private SupportedDevices(boolean cpu, boolean cuda, boolean dml) {
this.cpu = cpu;
this.cuda = cuda;
this.dml = dml;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import static jp.hiroshiba.voicevoxcore.GlobalInfo.SupportedDevices;

import com.google.gson.Gson;
import jakarta.annotation.Nonnull;
import jakarta.annotation.Nullable;
import java.util.Optional;
Expand Down Expand Up @@ -122,16 +121,10 @@ private Onnxruntime(@Nullable String filename) {
* @return {@link SupportedDevices}。
*/
public SupportedDevices supportedDevices() {
Gson gson = new Gson();
String supportedDevicesJson = rsSupportedDevices();
SupportedDevices supportedDevices = gson.fromJson(supportedDevicesJson, SupportedDevices.class);
if (supportedDevices == null) {
throw new NullPointerException("supported_devices");
}
return supportedDevices;
return rsSupportedDevices();
}

private native void rsNew(@Nullable String filename);

private native String rsSupportedDevices();
private native SupportedDevices rsSupportedDevices();
}
18 changes: 14 additions & 4 deletions crates/voicevox_core_java_api/src/onnxruntime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use jni::{
JNIEnv,
};

use crate::common::throw_if_err;
use crate::{common::throw_if_err, object};

// SAFETY: voicevox_core_java_apiを構成するライブラリの中に、これと同名のシンボルは存在しない
#[duplicate_item(
Expand Down Expand Up @@ -54,8 +54,18 @@ unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_blocking_Onnxruntime_rs
let this = *env.get_rust_field::<_, _, &'static voicevox_core::blocking::Onnxruntime>(
&this, "handle",
)?;
let json = this.supported_devices()?.to_json().to_string();
let json = env.new_string(json)?;
Ok(json.into_raw())
let devices = this.supported_devices()?;

assert!(match devices.to_json() {
serde_json::Value::Object(o) => o.len() == 3, // `cpu`, `cuda`, `dml`
_ => false,
});

let devices = env.new_object(
object!("GlobalInfo$SupportedDevices"),
"(ZZZ)V",
&[devices.cpu.into(), devices.cuda.into(), devices.dml.into()],
)?;
Ok(devices.into_raw())
})
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import dataclasses
from typing import Literal, NewType, TypeAlias
from typing import Literal, NewType, NoReturn, TypeAlias
from uuid import UUID

import pydantic
from pydantic_core import ArgsKwargs

from .._rust import _to_zenkaku, _validate_pronunciation
from ._please_do_not_use import _Reserved
Expand Down Expand Up @@ -137,6 +138,9 @@ class SupportedDevices:

あくまでONNX Runtimeが対応しているデバイスの情報であることに注意。GPUが使える環境ではなかったとしても
``cuda`` や ``dml`` は ``True`` を示しうる。

JSONからの変換も含め、VOICEVOX CORE以外が作ることはできない。作ろうとした場合
``TypeError`` となる。
Comment on lines +142 to +143
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Java APIのとちょっと文面を変えているのは、Java APIがGSONだけに言及すればいいのに対してPythonのdataclassでは通常のコンストラクタのことがあるから。

"""

cpu: bool
Expand All @@ -162,6 +166,13 @@ class SupportedDevices:
(``DmlExecutionProvider``)に対応する。必要な環境についてはそちらを参照。
"""

@pydantic.model_validator(mode="before")
@staticmethod
def _deny_unless_from_pyo3(data: ArgsKwargs) -> ArgsKwargs:
if "I AM FROM PYO3" not in data.args:
raise TypeError("You cannot deserialize `SupportedDevices`")
return ArgsKwargs((), kwargs=data.kwargs)


AccelerationMode: TypeAlias = Literal["AUTO", "CPU", "GPU"] | _Reserved
"""
Expand Down
18 changes: 17 additions & 1 deletion crates/voicevox_core_python_api/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use pyo3::{
use serde::{de::DeserializeOwned, Serialize};
use serde_json::json;
use uuid::Uuid;
use voicevox_core::{AccelerationMode, AccentPhrase, StyleId, VoiceModelMeta};
use voicevox_core::{AccelerationMode, AccentPhrase, StyleId, SupportedDevices, VoiceModelMeta};

use crate::{
AnalyzeTextError, GetSupportedDevicesError, GpuSupportError, InitInferenceRuntimeError,
Expand Down Expand Up @@ -255,6 +255,22 @@ pub(crate) impl<T> voicevox_core::Result<T> {
}
}

#[ext(SupportedDevicesExt)]
impl SupportedDevices {
pub(crate) fn to_py(self, py: Python<'_>) -> PyResult<&PyAny> {
assert!(match self.to_json() {
serde_json::Value::Object(o) => o.len() == 3, // `cpu`, `cuda`, `dml`
_ => false,
});

let cls = py.import("voicevox_core")?.getattr("SupportedDevices")?;
cls.call(
("I AM FROM PYO3",),
Some([("cpu", self.cpu), ("cuda", self.cuda), ("dml", self.dml)].into_py_dict(py)),
)
}
}

#[ext]
impl<T> std::result::Result<T, uuid::Error> {
fn into_py_value_result(self) -> PyResult<T> {
Expand Down
22 changes: 8 additions & 14 deletions crates/voicevox_core_python_api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,8 @@ mod blocking {
use voicevox_core::{AccelerationMode, AudioQuery, StyleId, UserDictWord};

use crate::{
convert::VoicevoxCoreResultExt as _, Closable, SingleTasked, VoiceModelFilePyFields,
convert::{SupportedDevicesExt as _, VoicevoxCoreResultExt as _},
Closable, SingleTasked, VoiceModelFilePyFields,
};

#[pyclass]
Expand Down Expand Up @@ -415,12 +416,7 @@ mod blocking {
}

fn supported_devices<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> {
let class = py
.import("voicevox_core")?
.getattr("SupportedDevices")?
.downcast()?;
let s = self.0.supported_devices().into_py_result(py)?;
crate::convert::to_pydantic_dataclass(s, class)
self.0.supported_devices().into_py_result(py)?.to_py(py)
}
}

Expand Down Expand Up @@ -888,7 +884,10 @@ mod asyncio {
use uuid::Uuid;
use voicevox_core::{AccelerationMode, AudioQuery, StyleId, UserDictWord};

use crate::{convert::VoicevoxCoreResultExt as _, Closable, Tokio, VoiceModelFilePyFields};
use crate::{
convert::{SupportedDevicesExt as _, VoicevoxCoreResultExt as _},
Closable, Tokio, VoiceModelFilePyFields,
};

#[pyclass]
#[derive(Clone)]
Expand Down Expand Up @@ -1017,12 +1016,7 @@ mod asyncio {
}

fn supported_devices<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> {
let class = py
.import("voicevox_core")?
.getattr("SupportedDevices")?
.downcast()?;
let s = self.0.supported_devices().into_py_result(py)?;
crate::convert::to_pydantic_dataclass(s, class)
self.0.supported_devices().into_py_result(py)?.to_py(py)
}
}

Expand Down