Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ONNX Runtimeとモデルのシグネチャを隔離する #675

Merged
merged 48 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
5ff2b59
ONNX Runtimeとモデルのシグネチャを隔離する
qryxip Nov 5, 2023
33245ff
`R: InferenceCore`を`SynthesisEngine`まで持っていく
qryxip Nov 5, 2023
4bd2281
`SupportedDevices::create`の実装を移動
qryxip Nov 5, 2023
f959911
不要なreexportを削除
qryxip Nov 5, 2023
55b04d3
`InferenceModels`の定義を`signatures`に移動
qryxip Nov 5, 2023
c38ad99
`ErrorRepr::GetSupportedDevices`の中身を`anyhow::Error`に
qryxip Nov 6, 2023
192417f
enum-map v3.0.0-beta.1を導入し、`EnumMap`駆動に
qryxip Nov 6, 2023
20db67a
Minor refactor
qryxip Nov 6, 2023
cc84068
Minor refactor
qryxip Nov 6, 2023
e0f29c6
色々再構成
qryxip Nov 8, 2023
cb1db34
Fix up
qryxip Nov 8, 2023
c3e08dd
`OnnxruntimeInferenceBuilder` → `OnnxruntimeRunContext`
qryxip Nov 8, 2023
e4b91ab
`impl SupportsInferenceOutput<_> for Onnxruntime`を拡張
qryxip Nov 8, 2023
8584d27
`SignatureKind` → `InferenceSignatureKind`
qryxip Nov 8, 2023
4795309
`LoadedModels`へのアクセスをメソッド越しにするのを徹底する
qryxip Nov 8, 2023
a5dbbdd
Minor refactor
qryxip Nov 8, 2023
525f4b1
`InferenceInput` → `InferenceInputSignature`
qryxip Nov 8, 2023
26476f5
相互参照
qryxip Nov 8, 2023
fbd7d1c
`fn input`まわりを明瞭にする
qryxip Nov 9, 2023
8b4f3b6
"signature"のkindではなく"model"のkindとする
qryxip Nov 9, 2023
c40afd5
"model"ではなく"inference"と呼ぶ
qryxip Nov 11, 2023
81b5804
ランタイムは任意次元任意個数の入出力ができると仮定する
qryxip Nov 11, 2023
120106b
voicevox_core_macrosを作り、"signatures"の実装をマクロ化
qryxip Nov 11, 2023
590ce48
`AnyTensor` → `OutputTensor`
qryxip Nov 11, 2023
c4d5ebe
`INFERENCE` → `KIND`
qryxip Nov 11, 2023
1b1b7bf
`status`を`infer`下に
qryxip Nov 11, 2023
c39f48c
`trait RunContext`を削除
qryxip Nov 11, 2023
c316209
"kind"を直接"group"と呼ぶことにする
qryxip Nov 11, 2023
2274a34
シグネチャの実行時チェック機構を入れる
qryxip Nov 12, 2023
b7d48f3
signaturesのマクロ化を完了させる
qryxip Nov 13, 2023
b6db1c0
Minor refactor
qryxip Nov 13, 2023
96a93e9
`InferenceGroup` → `InferenceDomain`
qryxip Nov 14, 2023
59d8779
Minor refactor
qryxip Nov 14, 2023
d0dc56f
`InferenceDomain::{INPUT,OUTPUT}_PARAM_INFOS`を統合
qryxip Nov 14, 2023
c654cd1
`InferenceDomain::PARAM_INFOS`にdocstring
qryxip Nov 14, 2023
868d3f6
voicevox_core_macrosにdocstring
qryxip Nov 14, 2023
0998793
`sealed::InputScalar`にFIXME
qryxip Nov 14, 2023
75fd7ac
"Domain"と"Operation"に分離
qryxip Nov 14, 2023
ad222c9
`InferenceOperationKind` → `InferenceOperationImpl`
qryxip Nov 15, 2023
7005c96
docを修正
qryxip Nov 15, 2023
9417992
Merge branch 'main' into split-onnxruntime-and-model-signatures
qryxip Nov 15, 2023
1655719
"voicevox_core内で" → "Rust APIクレート内で"
qryxip Nov 15, 2023
a73f22c
docを追記
qryxip Nov 15, 2023
f17919b
docを追記
qryxip Nov 15, 2023
48bdb1b
`InferenceDomain`のdocを書く
qryxip Nov 16, 2023
9d7d001
不要な文の削除
qryxip Nov 16, 2023
af828eb
Minor refactor
qryxip Nov 16, 2023
b6b7975
`ArrayExt`をマクロ内に押し込める
qryxip Nov 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions crates/voicevox_core/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ pub(crate) trait InferenceRuntime: 'static {
fn run(ctx: Self::RunContext<'_>) -> anyhow::Result<Vec<OutputTensor>>;
}

pub(crate) trait InferenceDomain: Copy + Enum {
/// A family of inference operations.
///
/// An implementor only names, through `Operation`, the type that identifies
/// its individual operations; that type must implement `InferenceOperation`.
pub(crate) trait InferenceDomain {
/// The type identifying each operation that belongs to this domain.
type Operation: InferenceOperation;
}

pub(crate) trait InferenceOperation: Copy + Enum {
Hiroshiba marked this conversation as resolved.
Show resolved Hide resolved
/// `{InferenceInputSignature,InferenceOutputSignature}::PARAM_INFOS`を集めたもの。
///
/// マクロ(voicevox_core_macros)で実装される前提。
Expand All @@ -54,7 +58,7 @@ pub(crate) trait InferenceSignature: Sized + Send + 'static {
type Domain: InferenceDomain;
type Input: InferenceInputSignature<Signature = Self>;
type Output: InferenceOutputSignature;
const KIND: Self::Domain;
const OPERATION: <Self::Domain as crate::infer::InferenceDomain>::Operation;
}

pub(crate) trait InferenceInputSignature: Send + 'static {
Expand Down
25 changes: 18 additions & 7 deletions crates/voicevox_core/src/infer/domain.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,35 @@
use enum_map::Enum;
use macros::{InferenceDomain, InferenceInputSignature, InferenceOutputSignature};
use macros::{InferenceInputSignature, InferenceOperation, InferenceOutputSignature};
use ndarray::{Array0, Array1, Array2};

use super::{InferenceInputSignature as _, InferenceOutputSignature as _, OutputTensor};
use super::{
InferenceDomain, InferenceInputSignature as _, InferenceOutputSignature as _, OutputTensor,
};

#[derive(Clone, Copy, Enum, InferenceDomain)]
pub(crate) enum InferenceKind {
#[inference_domain(
/// Type-level marker for the inference domain used by this crate.
///
/// The enum has no variants, so no value of it can be constructed; it is only
/// used as a type parameter.
pub(crate) enum InferenceDomainImpl {}

// The operations of this domain are the variants of `InferenceOperationImpl`.
impl InferenceDomain for InferenceDomainImpl {
type Operation = InferenceOperationImpl;
}

#[derive(Clone, Copy, Enum, InferenceOperation)]
#[inference_operation(
type Domain = InferenceDomainImpl;
)]
pub(crate) enum InferenceOperationImpl {
#[inference_operation(
type Input = PredictDurationInput;
type Output = PredictDurationOutput;
)]
PredictDuration,

#[inference_domain(
#[inference_operation(
type Input = PredictIntonationInput;
type Output = PredictIntonationOutput;
)]
PredictIntonation,

#[inference_domain(
#[inference_operation(
type Input = DecodeInput;
type Output = DecodeOutput;
)]
Expand Down
53 changes: 28 additions & 25 deletions crates/voicevox_core/src/infer/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ use std::{

use anyhow::bail;
use educe::Educe;
use enum_map::EnumMap;
use enum_map::{Enum as _, EnumMap};
use itertools::{iproduct, Itertools as _};

use crate::{
error::{ErrorRepr, LoadModelError, LoadModelErrorKind, LoadModelResult},
infer::ParamInfo,
infer::{InferenceOperation, ParamInfo},
manifest::ModelInnerId,
metas::{SpeakerMeta, StyleId, StyleMeta, VoiceModelMeta},
voice_model::{VoiceModel, VoiceModelId},
Expand All @@ -26,11 +26,11 @@ use super::{

pub(crate) struct Status<R: InferenceRuntime, D: InferenceDomain> {
loaded_models: std::sync::Mutex<LoadedModels<R, D>>,
session_options: EnumMap<D, InferenceSessionOptions>,
session_options: EnumMap<D::Operation, InferenceSessionOptions>,
}

impl<R: InferenceRuntime, D: InferenceDomain> Status<R, D> {
Hiroshiba marked this conversation as resolved.
Show resolved Hide resolved
pub fn new(session_options: EnumMap<D, InferenceSessionOptions>) -> Self {
pub fn new(session_options: EnumMap<D::Operation, InferenceSessionOptions>) -> Self {
Self {
loaded_models: Default::default(),
session_options,
Expand All @@ -40,7 +40,7 @@ impl<R: InferenceRuntime, D: InferenceDomain> Status<R, D> {
pub async fn load_model(
&self,
model: &VoiceModel,
model_bytes: &EnumMap<D, Vec<u8>>,
model_bytes: &EnumMap<D::Operation, Vec<u8>>,
) -> Result<()> {
self.loaded_models
.lock()
Expand Down Expand Up @@ -241,30 +241,31 @@ impl<R: InferenceRuntime, D: InferenceDomain> LoadedModels<R, D> {
}

struct SessionSet<R: InferenceRuntime, D: InferenceDomain>(
EnumMap<D, Arc<std::sync::Mutex<R::Session>>>,
EnumMap<D::Operation, Arc<std::sync::Mutex<R::Session>>>,
);

impl<R: InferenceRuntime, D: InferenceDomain> SessionSet<R, D> {
fn new(
model_bytes: &EnumMap<D, Vec<u8>>,
options: &EnumMap<D, InferenceSessionOptions>,
model_bytes: &EnumMap<D::Operation, Vec<u8>>,
options: &EnumMap<D::Operation, InferenceSessionOptions>,
) -> anyhow::Result<Self> {
let mut sessions = model_bytes
.iter()
.map(|(k, m)| {
let (expected_input_param_infos, expected_output_param_infos) = D::PARAM_INFOS[k];
.map(|(op, model_bytes)| {
let (expected_input_param_infos, expected_output_param_infos) =
<D::Operation as InferenceOperation>::PARAM_INFOS[op];

let (sess, actual_input_param_infos, actual_output_param_infos) =
R::new_session(|| model_file::decrypt(m), options[k])?;
R::new_session(|| model_file::decrypt(model_bytes), options[op])?;

check_param_infos(expected_input_param_infos, &actual_input_param_infos)?;
check_param_infos(expected_output_param_infos, &actual_output_param_infos)?;

Ok((k.into_usize(), std::sync::Mutex::new(sess).into()))
Ok((op.into_usize(), std::sync::Mutex::new(sess).into()))
})
.collect::<anyhow::Result<HashMap<_, _>>>()?;

return Ok(Self(EnumMap::<D, _>::from_fn(|k| {
return Ok(Self(EnumMap::<D::Operation, _>::from_fn(|k| {
sessions.remove(&k.into_usize()).expect("should exist")
})));

Expand Down Expand Up @@ -305,7 +306,7 @@ impl<R: InferenceRuntime, D: InferenceDomain> SessionSet<R, D> {
I::Signature: InferenceSignature<Domain = D>,
{
SessionCell {
inner: self.0[I::Signature::KIND].clone(),
inner: self.0[I::Signature::OPERATION].clone(),
marker: PhantomData,
}
}
Expand Down Expand Up @@ -333,8 +334,10 @@ mod tests {
use rstest::rstest;

use crate::{
infer::domain::InferenceKind, macros::tests::assert_debug_fmt_eq,
synthesizer::InferenceRuntimeImpl, test_util::open_default_vvm_file,
infer::domain::{InferenceDomainImpl, InferenceOperationImpl},
macros::tests::assert_debug_fmt_eq,
synthesizer::InferenceRuntimeImpl,
test_util::open_default_vvm_file,
};

use super::{super::InferenceSessionOptions, Status};
Expand All @@ -351,23 +354,23 @@ mod tests {
let light_session_options = InferenceSessionOptions::new(cpu_num_threads, false);
let heavy_session_options = InferenceSessionOptions::new(cpu_num_threads, use_gpu);
let session_options = enum_map! {
InferenceKind::PredictDuration
| InferenceKind::PredictIntonation => light_session_options,
InferenceKind::Decode => heavy_session_options,
InferenceOperationImpl::PredictDuration
| InferenceOperationImpl::PredictIntonation => light_session_options,
InferenceOperationImpl::Decode => heavy_session_options,
};
let status = Status::<InferenceRuntimeImpl, InferenceKind>::new(session_options);
let status = Status::<InferenceRuntimeImpl, InferenceDomainImpl>::new(session_options);

assert_eq!(
light_session_options,
status.session_options[InferenceKind::PredictDuration],
status.session_options[InferenceOperationImpl::PredictDuration],
);
assert_eq!(
light_session_options,
status.session_options[InferenceKind::PredictIntonation],
status.session_options[InferenceOperationImpl::PredictIntonation],
);
assert_eq!(
heavy_session_options,
status.session_options[InferenceKind::Decode],
status.session_options[InferenceOperationImpl::Decode],
);

assert!(status.loaded_models.lock().unwrap().0.is_empty());
Expand All @@ -376,7 +379,7 @@ mod tests {
#[rstest]
#[tokio::test]
async fn status_load_model_works() {
let status = Status::<InferenceRuntimeImpl, InferenceKind>::new(
let status = Status::<InferenceRuntimeImpl, InferenceDomainImpl>::new(
enum_map!(_ => InferenceSessionOptions::new(0, false)),
);
let model = &open_default_vvm_file().await;
Expand All @@ -389,7 +392,7 @@ mod tests {
#[rstest]
#[tokio::test]
async fn status_is_model_loaded_works() {
let status = Status::<InferenceRuntimeImpl, InferenceKind>::new(
let status = Status::<InferenceRuntimeImpl, InferenceDomainImpl>::new(
enum_map!(_ => InferenceSessionOptions::new(0, false)),
);
let vvm = open_default_vvm_file().await;
Expand Down
13 changes: 7 additions & 6 deletions crates/voicevox_core/src/inference_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use enum_map::enum_map;

use crate::infer::{
domain::{
DecodeInput, DecodeOutput, InferenceKind, PredictDurationInput, PredictDurationOutput,
PredictIntonationInput, PredictIntonationOutput,
DecodeInput, DecodeOutput, InferenceDomainImpl, InferenceOperationImpl,
PredictDurationInput, PredictDurationOutput, PredictIntonationInput,
PredictIntonationOutput,
},
status::Status,
InferenceRuntime, InferenceSessionOptions,
Expand All @@ -14,7 +15,7 @@ use super::*;
const PHONEME_LENGTH_MINIMAL: f32 = 0.01;

pub(crate) struct InferenceCore<R: InferenceRuntime> {
status: Status<R, InferenceKind>,
status: Status<R, InferenceDomainImpl>,
}

impl<R: InferenceRuntime> InferenceCore<R> {
Expand All @@ -27,9 +28,9 @@ impl<R: InferenceRuntime> InferenceCore<R> {
let heavy_session_options = InferenceSessionOptions::new(cpu_num_threads, use_gpu);

let status = Status::new(enum_map! {
InferenceKind::PredictDuration
| InferenceKind::PredictIntonation => light_session_options,
InferenceKind::Decode => heavy_session_options,
InferenceOperationImpl::PredictDuration
| InferenceOperationImpl::PredictIntonation => light_session_options,
InferenceOperationImpl::Decode => heavy_session_options,
});
Ok(Self { status })
} else {
Expand Down
4 changes: 2 additions & 2 deletions crates/voicevox_core/src/voice_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use futures::future::join3;
use serde::{de::DeserializeOwned, Deserialize};

use super::*;
use crate::infer::domain::InferenceKind;
use crate::infer::domain::InferenceOperationImpl;
use std::{
collections::{BTreeMap, HashMap},
io,
Expand Down Expand Up @@ -40,7 +40,7 @@ pub struct VoiceModel {
impl VoiceModel {
pub(crate) async fn read_inference_models(
&self,
) -> LoadModelResult<EnumMap<InferenceKind, Vec<u8>>> {
) -> LoadModelResult<EnumMap<InferenceOperationImpl, Vec<u8>>> {
let reader = VvmEntryReader::open(&self.path).await?;
let (decode_model_result, predict_duration_model_result, predict_intonation_model_result) =
join3(
Expand Down
41 changes: 34 additions & 7 deletions crates/voicevox_core_macros/src/inference_domain.rs
Hiroshiba marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,41 @@ use syn::{
ItemType, Type, Variant,
};

pub(crate) fn derive_inference_domain(
pub(crate) fn derive_inference_operation(
input: &DeriveInput,
) -> syn::Result<proc_macro2::TokenStream> {
let DeriveInput {
attrs,
vis,
ident: domain_name,
ident: operation_ty_name,
generics,
data,
..
} = input;

deny_generics(generics)?;

let AssocTypeDomain(domain_ty) = attrs
.iter()
.find(|a| a.path().is_ident("inference_operation"))
.ok_or_else(|| {
syn::Error::new(
proc_macro2::Span::call_site(),
"missing `#[inference_operation(…)]`",
)
})?
.parse_args()?;

let variants = unit_enum_variants(data)?
.into_iter()
.map(|(attrs, variant_name)| {
let AssocTypes { input, output } = attrs
.iter()
.find(|a| a.path().is_ident("inference_domain"))
.find(|a| a.path().is_ident("inference_operation"))
.ok_or_else(|| {
syn::Error::new(
proc_macro2::Span::call_site(),
"missing `#[inference_domain(…)]`",
"missing `#[inference_operation(…)]`",
)
})?
.parse_args()?;
Expand All @@ -47,16 +59,18 @@ pub(crate) fn derive_inference_domain(
#vis enum #variant_name {}

impl crate::infer::InferenceSignature for #variant_name {
type Domain = #domain_name;
type Domain = #domain_ty;
type Input = #input_ty;
type Output = #output_ty;
const KIND: Self::Domain = #domain_name :: #variant_name;

const OPERATION: <Self::Domain as crate::infer::InferenceDomain>::Operation =
#operation_ty_name :: #variant_name;
}
}
});

return Ok(quote! {
impl crate::infer::InferenceDomain for #domain_name {
impl crate::infer::InferenceOperation for #operation_ty_name {
const PARAM_INFOS: ::enum_map::EnumMap<
Self,
(
Expand All @@ -74,6 +88,19 @@ pub(crate) fn derive_inference_domain(
#(#signatures)*
});

struct AssocTypeDomain(Type);

impl Parse for AssocTypeDomain {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
let ItemType { ident, ty, .. } = input.parse()?;

if ident != "Domain" {
return Err(syn::Error::new(ident.span(), "expected `Domain`"));
}
Ok(Self(*ty))
}
}

struct AssocTypes {
input: Type,
output: Type,
Expand Down
Loading
Loading