From 5ff2b5948addd5045dfa100ee0c17077993764b7 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Mon, 6 Nov 2023 03:28:44 +0900 Subject: [PATCH 01/47] =?UTF-8?q?ONNX=20Runtime=E3=81=A8=E3=83=A2=E3=83=87?= =?UTF-8?q?=E3=83=AB=E3=81=AE=E3=82=B7=E3=82=B0=E3=83=8D=E3=83=81=E3=83=A3?= =?UTF-8?q?=E3=82=92=E9=9A=94=E9=9B=A2=E3=81=99=E3=82=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 1 + Cargo.toml | 1 + crates/voicevox_core/Cargo.toml | 1 + crates/voicevox_core/src/infer.rs | 92 +++++++ crates/voicevox_core/src/infer/runtimes.rs | 3 + .../src/infer/runtimes/onnxruntime.rs | 136 +++++++++ crates/voicevox_core/src/infer/signatures.rs | 87 ++++++ crates/voicevox_core/src/inference_core.rs | 83 +++--- crates/voicevox_core/src/lib.rs | 1 + crates/voicevox_core/src/status.rs | 259 +++--------------- crates/voicevox_core_c_api/Cargo.toml | 2 +- 11 files changed, 397 insertions(+), 269 deletions(-) create mode 100644 crates/voicevox_core/src/infer.rs create mode 100644 crates/voicevox_core/src/infer/runtimes.rs create mode 100644 crates/voicevox_core/src/infer/runtimes/onnxruntime.rs create mode 100644 crates/voicevox_core/src/infer/signatures.rs diff --git a/Cargo.lock b/Cargo.lock index 50868f63b..08e25f93e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4280,6 +4280,7 @@ dependencies = [ "indexmap 2.0.0", "itertools", "nanoid", + "ndarray", "once_cell", "onnxruntime", "open_jtalk", diff --git a/Cargo.toml b/Cargo.toml index b6237098a..bb98404f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ easy-ext = "1.0.1" fs-err = { version = "2.9.0", features = ["tokio"] } futures = "0.3.26" itertools = "0.10.5" +ndarray = "0.15.6" once_cell = "1.18.0" regex = "1.10.0" rstest = "0.15.0" diff --git a/crates/voicevox_core/Cargo.toml b/crates/voicevox_core/Cargo.toml index 3a23b794a..bee2f822c 100644 --- a/crates/voicevox_core/Cargo.toml +++ b/crates/voicevox_core/Cargo.toml @@ -22,6 +22,7 @@ futures.workspace = true 
indexmap = { version = "2.0.0", features = ["serde"] } itertools.workspace = true nanoid = "0.4.0" +ndarray.workspace = true once_cell.workspace = true regex.workspace = true serde.workspace = true diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs new file mode 100644 index 000000000..efc66b1e8 --- /dev/null +++ b/crates/voicevox_core/src/infer.rs @@ -0,0 +1,92 @@ +pub(crate) mod runtimes; +pub(crate) mod signatures; + +use std::{fmt::Debug, marker::PhantomData, sync::Arc}; + +use derive_new::new; +use ndarray::{Array, Dimension, LinalgScalar}; +use thiserror::Error; + +pub(crate) trait InferenceRuntime: Copy { + type Session: Session; + type RunBuilder<'a>: RunBuilder<'a, Runtime = Self>; +} + +pub(crate) trait Session: Sized + 'static { + fn new( + model: impl FnOnce() -> std::result::Result, DecryptModelError>, + options: SessionOptions, + ) -> anyhow::Result; +} + +pub(crate) trait RunBuilder<'a>: + From<&'a mut ::Session> +{ + type Runtime: InferenceRuntime; + fn input(&mut self, tensor: Array) -> &mut Self; +} + +pub(crate) trait InputScalar: LinalgScalar + Debug + sealed::OnnxruntimeInputScalar {} + +impl InputScalar for i64 {} +impl InputScalar for f32 {} + +pub(crate) trait Signature: Sized + Send + Sync + 'static { + type SessionSet; + type Output; + fn get_session( + session_set: &Self::SessionSet, + ) -> &Arc>>; + fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>); +} + +pub(crate) trait Output: Sized + Send { + fn run(ctx: R::RunBuilder<'_>) -> anyhow::Result; +} + +pub(crate) struct TypedSession { + inner: R::Session, + marker: PhantomData, +} + +impl TypedSession { + pub(crate) fn new( + model: impl FnOnce() -> std::result::Result, DecryptModelError>, + options: SessionOptions, + ) -> anyhow::Result { + let inner = R::Session::new(model, options)?; + Ok(Self { + inner, + marker: PhantomData, + }) + } + + pub(crate) fn run(&mut self, sig: S) -> anyhow::Result + where + S::Output: Output, + { + let mut ctx = 
R::RunBuilder::from(&mut self.inner); + sig.input(&mut ctx); + S::Output::run(ctx) + } +} + +#[derive(new, Clone, Copy)] +pub(crate) struct SessionOptions { + pub(crate) cpu_num_threads: u16, + pub(crate) use_gpu: bool, +} + +#[derive(Error, Debug)] +#[error("不正なモデルファイルです")] +pub(crate) struct DecryptModelError; + +mod sealed { + pub(crate) trait OnnxruntimeInputScalar: + onnxruntime::TypeToTensorElementDataType + { + } + + impl OnnxruntimeInputScalar for i64 {} + impl OnnxruntimeInputScalar for f32 {} +} diff --git a/crates/voicevox_core/src/infer/runtimes.rs b/crates/voicevox_core/src/infer/runtimes.rs new file mode 100644 index 000000000..7934027b6 --- /dev/null +++ b/crates/voicevox_core/src/infer/runtimes.rs @@ -0,0 +1,3 @@ +mod onnxruntime; + +pub(crate) use self::onnxruntime::Onnxruntime; diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs new file mode 100644 index 000000000..a26abbb74 --- /dev/null +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -0,0 +1,136 @@ +use ndarray::{Array, Dimension}; +use once_cell::sync::Lazy; +use onnxruntime::{environment::Environment, GraphOptimizationLevel, LoggingLevel}; + +use crate::infer::{ + DecryptModelError, InferenceRuntime, InputScalar, Output, RunBuilder, Session, SessionOptions, +}; + +pub(crate) use self::assert_send::AssertSend; + +#[derive(Clone, Copy)] +pub(crate) enum Onnxruntime {} + +impl InferenceRuntime for Onnxruntime { + type Session = AssertSend>; + type RunBuilder<'a> = OnnxruntimeInferenceBuilder<'a>; +} + +impl Session for AssertSend> { + fn new( + model: impl FnOnce() -> std::result::Result, DecryptModelError>, + options: SessionOptions, + ) -> anyhow::Result { + let mut builder = ENVIRONMENT + .new_session_builder()? + .with_optimization_level(GraphOptimizationLevel::Basic)? + .with_intra_op_num_threads(options.cpu_num_threads.into())? 
+ .with_inter_op_num_threads(options.cpu_num_threads.into())?; + + if options.use_gpu { + #[cfg(feature = "directml")] + { + use onnxruntime::ExecutionMode; + + builder = builder + .with_disable_mem_pattern()? + .with_execution_mode(ExecutionMode::ORT_SEQUENTIAL)? + .with_append_execution_provider_directml(0)?; + } + + #[cfg(not(feature = "directml"))] + { + builder = builder.with_append_execution_provider_cuda(Default::default())?; + } + } + + let model = model()?; + let this = builder.with_model_from_memory(model)?.into(); + return Ok(this); + + static ENVIRONMENT: Lazy = Lazy::new(|| { + Environment::builder() + .with_name(env!("CARGO_PKG_NAME")) + .with_log_level(LOGGING_LEVEL) + .build() + .unwrap() + }); + + const LOGGING_LEVEL: LoggingLevel = if cfg!(debug_assertions) { + LoggingLevel::Verbose + } else { + LoggingLevel::Warning + }; + } +} + +pub(crate) struct OnnxruntimeInferenceBuilder<'sess> { + sess: &'sess mut AssertSend>, + inputs: Vec>, +} + +impl<'sess> From<&'sess mut AssertSend>> + for OnnxruntimeInferenceBuilder<'sess> +{ + fn from(sess: &'sess mut AssertSend>) -> Self { + Self { + sess, + inputs: vec![], + } + } +} + +impl<'sess> RunBuilder<'sess> for OnnxruntimeInferenceBuilder<'sess> { + type Runtime = Onnxruntime; + + fn input(&mut self, tensor: Array) -> &mut Self { + self.inputs + .push(Box::new(onnxruntime::session::NdArray::new(tensor))); + self + } +} + +impl Output for (Vec,) { + fn run( + OnnxruntimeInferenceBuilder { sess, mut inputs }: OnnxruntimeInferenceBuilder<'_>, + ) -> anyhow::Result { + let outputs = sess.run(inputs.iter_mut().map(|t| &mut **t as &mut _).collect())?; + + // FIXME: 2個以上の出力や二次元以上の出力をちゃんとしたやりかたで弾く + Ok((outputs[0].as_slice().unwrap().to_owned(),)) + } +} + +// FIXME: 以下のことをちゃんと確認した後、onnxruntime-rs側で`Session`が`Send`であると宣言する。 +// https://github.com/VOICEVOX/voicevox_core/issues/307#issuecomment-1276184614 +mod assert_send { + use std::ops::{Deref, DerefMut}; + + pub(crate) struct AssertSend(T); + + impl From> + for 
AssertSend> + { + fn from(session: onnxruntime::session::Session<'static>) -> Self { + Self(session) + } + } + + impl Deref for AssertSend { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl DerefMut for AssertSend { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + + // SAFETY: `Session` is probably "send"able. + #[allow(unsafe_code)] + unsafe impl Send for AssertSend {} +} diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/signatures.rs new file mode 100644 index 000000000..764d70b8d --- /dev/null +++ b/crates/voicevox_core/src/infer/signatures.rs @@ -0,0 +1,87 @@ +use std::sync::Arc; + +use ndarray::{Array0, Array1, Array2}; + +use crate::infer::{InferenceRuntime, RunBuilder, Signature, TypedSession}; + +pub(crate) struct SessionSet { + pub(crate) predict_duration: Arc>>, + pub(crate) predict_intonation: Arc>>, + pub(crate) decode: Arc>>, +} + +pub(crate) struct PredictDuration { + pub(crate) phoneme: Array1, + pub(crate) speaker_id: Array1, +} + +impl Signature for PredictDuration { + type SessionSet = SessionSet; + type Output = (Vec,); + + fn get_session( + session_set: &Self::SessionSet, + ) -> &Arc>> { + &session_set.predict_duration + } + + fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>) { + ctx.input(self.phoneme).input(self.speaker_id); + } +} + +pub(crate) struct PredictIntonation { + pub(crate) length: Array0, + pub(crate) vowel_phoneme: Array1, + pub(crate) consonant_phoneme: Array1, + pub(crate) start_accent: Array1, + pub(crate) end_accent: Array1, + pub(crate) start_accent_phrase: Array1, + pub(crate) end_accent_phrase: Array1, + pub(crate) speaker_id: Array1, +} + +impl Signature for PredictIntonation { + type SessionSet = SessionSet; + type Output = (Vec,); + + fn get_session( + session_set: &Self::SessionSet, + ) -> &Arc>> { + &session_set.predict_intonation + } + + fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>) { + 
ctx.input(self.length) + .input(self.vowel_phoneme) + .input(self.consonant_phoneme) + .input(self.start_accent) + .input(self.end_accent) + .input(self.start_accent_phrase) + .input(self.end_accent_phrase) + .input(self.speaker_id); + } +} + +pub(crate) struct Decode { + pub(crate) f0: Array2, + pub(crate) phoneme: Array2, + pub(crate) speaker_id: Array1, +} + +impl Signature for Decode { + type SessionSet = SessionSet; + type Output = (Vec,); + + fn get_session( + session_set: &Self::SessionSet, + ) -> &Arc>> { + &session_set.decode + } + + fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>) { + ctx.input(self.f0) + .input(self.phoneme) + .input(self.speaker_id); + } +} diff --git a/crates/voicevox_core/src/inference_core.rs b/crates/voicevox_core/src/inference_core.rs index 4b0d08be2..d1c7831c5 100644 --- a/crates/voicevox_core/src/inference_core.rs +++ b/crates/voicevox_core/src/inference_core.rs @@ -1,6 +1,6 @@ use self::status::*; use super::*; -use onnxruntime::{ndarray, session::NdArray}; +use crate::infer::signatures::{Decode, PredictDuration, PredictIntonation}; const PHONEME_LENGTH_MINIMAL: f32 = 0.01; @@ -60,12 +60,15 @@ impl InferenceCore { let (model_id, model_inner_id) = self.status.ids_for(style_id)?; - let phoneme_vector_array = NdArray::new(ndarray::arr1(phoneme_vector)); - let speaker_id_array = NdArray::new(ndarray::arr1(&[model_inner_id.raw_id().into()])); - - let mut output = self + let (mut output,) = self .status - .predict_duration_session_run(&model_id, phoneme_vector_array, speaker_id_array) + .run_session( + &model_id, + PredictDuration { + phoneme: ndarray::arr1(phoneme_vector), + speaker_id: ndarray::arr1(&[model_inner_id.raw_id().into()]), + }, + ) .await?; for output_item in output.iter_mut() { @@ -95,29 +98,24 @@ impl InferenceCore { let (model_id, model_inner_id) = self.status.ids_for(style_id)?; - let length_array = NdArray::new(ndarray::arr0(length as i64)); - let vowel_phoneme_vector_array = 
NdArray::new(ndarray::arr1(vowel_phoneme_vector)); - let consonant_phoneme_vector_array = NdArray::new(ndarray::arr1(consonant_phoneme_vector)); - let start_accent_vector_array = NdArray::new(ndarray::arr1(start_accent_vector)); - let end_accent_vector_array = NdArray::new(ndarray::arr1(end_accent_vector)); - let start_accent_phrase_vector_array = - NdArray::new(ndarray::arr1(start_accent_phrase_vector)); - let end_accent_phrase_vector_array = NdArray::new(ndarray::arr1(end_accent_phrase_vector)); - let speaker_id_array = NdArray::new(ndarray::arr1(&[model_inner_id.raw_id().into()])); - - self.status - .predict_intonation_session_run( + let (output,) = self + .status + .run_session( &model_id, - length_array, - vowel_phoneme_vector_array, - consonant_phoneme_vector_array, - start_accent_vector_array, - end_accent_vector_array, - start_accent_phrase_vector_array, - end_accent_phrase_vector_array, - speaker_id_array, + PredictIntonation { + length: ndarray::arr0(length as i64), + vowel_phoneme: ndarray::arr1(vowel_phoneme_vector), + consonant_phoneme: ndarray::arr1(consonant_phoneme_vector), + start_accent: ndarray::arr1(start_accent_vector), + end_accent: ndarray::arr1(end_accent_vector), + start_accent_phrase: ndarray::arr1(start_accent_phrase_vector), + end_accent_phrase: ndarray::arr1(end_accent_phrase_vector), + speaker_id: ndarray::arr1(&[model_inner_id.raw_id().into()]), + }, ) - .await + .await?; + + Ok(output) } pub async fn decode( @@ -150,22 +148,23 @@ impl InferenceCore { padding_size, ); - let f0_array = NdArray::new( - ndarray::arr1(&f0_with_padding) - .into_shape([length_with_padding, 1]) - .unwrap(), - ); - let phoneme_array = NdArray::new( - ndarray::arr1(&phoneme_with_padding) - .into_shape([length_with_padding, phoneme_size]) - .unwrap(), - ); - let speaker_id_array = NdArray::new(ndarray::arr1(&[model_inner_id.raw_id().into()])); + let (output,) = self + .status + .run_session( + &model_id, + Decode { + f0: ndarray::arr1(&f0_with_padding) + 
.into_shape([length_with_padding, 1]) + .unwrap(), + phoneme: ndarray::arr1(&phoneme_with_padding) + .into_shape([length_with_padding, phoneme_size]) + .unwrap(), + speaker_id: ndarray::arr1(&[model_inner_id.raw_id().into()]), + }, + ) + .await?; - self.status - .decode_session_run(&model_id, f0_array, phoneme_array, speaker_id_array) - .await - .map(|output| Self::trim_padding_from_output(output, padding_size)) + Ok(Self::trim_padding_from_output(output, padding_size)) } fn make_f0_with_padding( diff --git a/crates/voicevox_core/src/lib.rs b/crates/voicevox_core/src/lib.rs index 798515fb9..407f0b8f4 100644 --- a/crates/voicevox_core/src/lib.rs +++ b/crates/voicevox_core/src/lib.rs @@ -6,6 +6,7 @@ mod devices; /// cbindgen:ignore mod engine; mod error; +mod infer; mod inference_core; mod macros; mod manifest; diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index 64a402683..46e462d1a 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -1,23 +1,16 @@ use super::*; -use itertools::iproduct; -use once_cell::sync::Lazy; -use onnxruntime::{ - environment::Environment, - ndarray::{Ix0, Ix1, Ix2}, - session::{NdArray, Session}, - GraphOptimizationLevel, LoggingLevel, +use crate::infer::{ + runtimes::Onnxruntime, + signatures::{Decode, PredictDuration, PredictIntonation, SessionSet}, + DecryptModelError, Output, SessionOptions, Signature, TypedSession, }; +use derive_more::Index; +use itertools::iproduct; +use std::path::Path; use std::sync::Arc; -use std::{env, path::Path}; -use tracing::error; mod model_file; -cfg_if! 
{ - if #[cfg(not(feature="directml"))]{ - use onnxruntime::CudaProviderOptions; - } -} use std::collections::BTreeMap; pub struct Status { @@ -26,31 +19,6 @@ pub struct Status { heavy_session_options: SessionOptions, // 重いモデルはこちらを使う } -#[derive(new, Getters)] -struct SessionOptions { - cpu_num_threads: u16, - use_gpu: bool, -} - -#[derive(thiserror::Error, Debug)] -#[error("不正なモデルファイルです")] -struct DecryptModelError; - -static ENVIRONMENT: Lazy = Lazy::new(|| { - cfg_if! { - if #[cfg(debug_assertions)]{ - const LOGGING_LEVEL: LoggingLevel = LoggingLevel::Verbose; - } else{ - const LOGGING_LEVEL: LoggingLevel = LoggingLevel::Warning; - } - } - Environment::builder() - .with_name(env!("CARGO_PKG_NAME")) - .with_log_level(LOGGING_LEVEL) - .build() - .unwrap() -}); - impl Status { pub fn new(use_gpu: bool, cpu_num_threads: u16) -> Self { Self { @@ -116,13 +84,13 @@ impl Status { self.loaded_models.lock().unwrap().contains_style(style_id) } - fn new_session( + fn new_session( &self, model: &[u8], session_options: &SessionOptions, path: impl AsRef, - ) -> LoadModelResult> { - self.new_session_from_bytes(|| model_file::decrypt(model), session_options) + ) -> LoadModelResult> { + TypedSession::::new(|| model_file::decrypt(model), *session_options) .map_err(|source| LoadModelError { path: path.as_ref().to_owned(), context: LoadModelErrorKind::InvalidModelData, @@ -130,36 +98,6 @@ impl Status { }) } - fn new_session_from_bytes( - &self, - model_bytes: impl FnOnce() -> std::result::Result, DecryptModelError>, - session_options: &SessionOptions, - ) -> anyhow::Result> { - let session_builder = ENVIRONMENT - .new_session_builder()? - .with_optimization_level(GraphOptimizationLevel::Basic)? - .with_intra_op_num_threads(*session_options.cpu_num_threads() as i32)? - .with_inter_op_num_threads(*session_options.cpu_num_threads() as i32)?; - - let session_builder = if *session_options.use_gpu() { - cfg_if! 
{ - if #[cfg(feature = "directml")]{ - session_builder - .with_disable_mem_pattern()? - .with_execution_mode(onnxruntime::ExecutionMode::ORT_SEQUENTIAL)? - .with_append_execution_provider_directml(0)? - } else { - let options = CudaProviderOptions::default(); - session_builder.with_append_execution_provider_cuda(options)? - } - } - } else { - session_builder - }; - - Ok(session_builder.with_model_from_memory(model_bytes()?)?) - } - pub fn validate_speaker_id(&self, style_id: StyleId) -> bool { self.is_loaded_model_by_style_id(style_id) } @@ -167,102 +105,25 @@ impl Status { /// # Panics /// /// `self`が`model_id`を含んでいないとき、パニックする。 - pub async fn predict_duration_session_run( - &self, - model_id: &VoiceModelId, - mut phoneme_vector_array: NdArray, - mut speaker_id_array: NdArray, - ) -> Result> { - let predict_duration = self.loaded_models.lock().unwrap().get( - model_id, - |SessionSet { - predict_duration, .. - }| predict_duration, - ); - - tokio::task::spawn_blocking(move || { - let mut predict_duration = predict_duration.lock().unwrap(); - - let output_tensors = predict_duration - .run(vec![&mut phoneme_vector_array, &mut speaker_id_array]) - .map_err(|e| ErrorRepr::InferenceFailed(e.into()))?; - Ok(output_tensors[0].as_slice().unwrap().to_owned()) - }) - .await - .unwrap() - } - - /// # Panics - /// - /// `self`が`model_id`を含んでいないとき、パニックする。 - #[allow(clippy::too_many_arguments)] - pub async fn predict_intonation_session_run( + pub(crate) async fn run_session( &self, model_id: &VoiceModelId, - mut length_array: NdArray, - mut vowel_phoneme_vector_array: NdArray, - mut consonant_phoneme_vector_array: NdArray, - mut start_accent_vector_array: NdArray, - mut end_accent_vector_array: NdArray, - mut start_accent_phrase_vector_array: NdArray, - mut end_accent_phrase_vector_array: NdArray, - mut speaker_id_array: NdArray, - ) -> Result> { - let predict_intonation = self.loaded_models.lock().unwrap().get( - model_id, - |SessionSet { - predict_intonation, .. 
- }| predict_intonation, - ); + input: S, + ) -> Result + where + S: Signature, + for<'a> &'a S::SessionSet: From<&'a SessionSet>, + S::Output: Output, + { + let sess = S::get_session::( + (&self.loaded_models.lock().unwrap()[model_id].session_set).into(), + ) + .clone(); tokio::task::spawn_blocking(move || { - let mut predict_intonation = predict_intonation.lock().unwrap(); - - let output_tensors = predict_intonation - .run(vec![ - &mut length_array, - &mut vowel_phoneme_vector_array, - &mut consonant_phoneme_vector_array, - &mut start_accent_vector_array, - &mut end_accent_vector_array, - &mut start_accent_phrase_vector_array, - &mut end_accent_phrase_vector_array, - &mut speaker_id_array, - ]) - .map_err(|e| ErrorRepr::InferenceFailed(e.into()))?; - Ok(output_tensors[0].as_slice().unwrap().to_owned()) - }) - .await - .unwrap() - } - - /// # Panics - /// - /// `self`が`model_id`を含んでいないとき、パニックする。 - pub async fn decode_session_run( - &self, - model_id: &VoiceModelId, - mut f0_array: NdArray, - mut phoneme_array: NdArray, - mut speaker_id_array: NdArray, - ) -> Result> { - let decode = self - .loaded_models - .lock() - .unwrap() - .get(model_id, |SessionSet { decode, .. 
}| decode); - - tokio::task::spawn_blocking(move || { - let mut decode = decode.lock().unwrap(); - - let output_tensors = decode - .run(vec![ - &mut f0_array, - &mut phoneme_array, - &mut speaker_id_array, - ]) - .map_err(|e| ErrorRepr::InferenceFailed(e.into()))?; - Ok(output_tensors[0].as_slice().unwrap().to_owned()) + let mut sess = sess.lock().unwrap(); + sess.run(input) + .map_err(|e| ErrorRepr::InferenceFailed(e).into()) }) .await .unwrap() @@ -272,13 +133,13 @@ impl Status { /// 読み込んだモデルの`Session`とそのメタ情報を保有し、追加/削除/取得の操作を提供する。 /// /// この構造体のメソッドは、すべて一瞬で完了すべきである。 -#[derive(Default)] +#[derive(Default, Index)] struct LoadedModels(BTreeMap); struct LoadedModel { model_inner_ids: BTreeMap, metas: VoiceModelMeta, - session_set: SessionSet, + session_set: SessionSet, } impl LoadedModels { @@ -314,17 +175,6 @@ impl LoadedModels { Ok((model_id.clone(), model_inner_id)) } - /// # Panics - /// - /// `self`が`model_id`を含んでいないとき、パニックする。 - fn get( - &self, - model_id: &VoiceModelId, - which: fn(&SessionSet) -> &Arc>>>, - ) -> Arc>>> { - which(&self.0[model_id].session_set).clone() - } - fn contains_voice_model(&self, model_id: &VoiceModelId) -> bool { self.0.contains_key(model_id) } @@ -366,9 +216,9 @@ impl LoadedModels { fn insert( &mut self, model: &VoiceModel, - predict_duration: Session<'static>, - predict_intonation: Session<'static>, - decode: Session<'static>, + predict_duration: TypedSession, + predict_intonation: TypedSession, + decode: TypedSession, ) -> Result<()> { self.ensure_acceptable(model)?; @@ -378,9 +228,9 @@ impl LoadedModels { model_inner_ids: model.model_inner_ids(), metas: model.metas().clone(), session_set: SessionSet { - predict_duration: Arc::new(std::sync::Mutex::new(predict_duration.into())), - predict_intonation: Arc::new(std::sync::Mutex::new(predict_intonation.into())), - decode: Arc::new(std::sync::Mutex::new(decode.into())), + predict_duration: Arc::new(std::sync::Mutex::new(predict_duration)), + predict_intonation: 
Arc::new(std::sync::Mutex::new(predict_intonation)), + decode: Arc::new(std::sync::Mutex::new(decode)), }, }, ); @@ -406,49 +256,6 @@ impl LoadedModels { } } -struct SessionSet { - predict_duration: Arc>>>, - predict_intonation: Arc>>>, - decode: Arc>>>, -} - -// FIXME: 以下のことをちゃんと確認した後、onnxruntime-rs側で`Session`が`Send`であると宣言する。 -// https://github.com/VOICEVOX/voicevox_core/issues/307#issuecomment-1276184614 - -use self::assert_send::AssertSend; - -mod assert_send { - use std::ops::{Deref, DerefMut}; - - use onnxruntime::session::Session; - - pub(super) struct AssertSend(T); - - impl From> for AssertSend> { - fn from(session: Session<'static>) -> Self { - Self(session) - } - } - - impl Deref for AssertSend { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.0 - } - } - - impl DerefMut for AssertSend { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } - } - - // SAFETY: `Session` is probably "send"able. - #[allow(unsafe_code)] - unsafe impl Send for AssertSend {} -} - #[cfg(test)] mod tests { diff --git a/crates/voicevox_core_c_api/Cargo.toml b/crates/voicevox_core_c_api/Cargo.toml index f187f0001..fad0e1b7b 100644 --- a/crates/voicevox_core_c_api/Cargo.toml +++ b/crates/voicevox_core_c_api/Cargo.toml @@ -52,7 +52,7 @@ easy-ext.workspace = true inventory = "0.3.4" libloading = "0.7.3" libtest-mimic = "0.6.0" -ndarray = "0.15.6" +ndarray.workspace = true ndarray-stats = "0.5.1" regex.workspace = true serde.workspace = true From 33245ff66bab6d97cb980c1f810ed2f73f3af056 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Mon, 6 Nov 2023 04:02:52 +0900 Subject: [PATCH 02/47] =?UTF-8?q?`R:=20InferenceCore`=E3=82=92`SynthesisEn?= =?UTF-8?q?gine`=E3=81=BE=E3=81=A7=E6=8C=81=E3=81=A3=E3=81=A6=E3=81=84?= =?UTF-8?q?=E3=81=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 26 +++++++++ crates/voicevox_core/Cargo.toml | 1 + .../src/engine/synthesis_engine.rs | 52 +++++++----------- 
crates/voicevox_core/src/infer.rs | 6 +-- .../src/infer/runtimes/onnxruntime.rs | 2 +- crates/voicevox_core/src/inference_core.rs | 15 ++++-- crates/voicevox_core/src/status.rs | 54 ++++++++++--------- crates/voicevox_core/src/synthesizer.rs | 9 +++- 8 files changed, 96 insertions(+), 69 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 08e25f93e..eb7c9cb65 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1205,6 +1205,18 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49457524c7e65648794c98283282a0b7c73b10018e7091f1cdcfff314fd7ae59" +[[package]] +name = "educe" +version = "0.4.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f0042ff8246a363dbe77d2ceedb073339e85a804b9a47636c6e016a9a32c05f" +dependencies = [ + "enum-ordinalize", + "proc-macro2", + "quote", + "syn 1.0.102", +] + [[package]] name = "either" version = "1.8.0" @@ -1226,6 +1238,19 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "enum-ordinalize" +version = "3.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bf1fa3f06bbff1ea5b1a9c7b14aa992a39657db60a2759457328d7e058f49ee" +dependencies = [ + "num-bigint", + "num-traits", + "proc-macro2", + "quote", + "syn 2.0.38", +] + [[package]] name = "env_logger" version = "0.9.1" @@ -4273,6 +4298,7 @@ dependencies = [ "derive_more", "duplicate", "easy-ext", + "educe", "fs-err", "futures", "heck", diff --git a/crates/voicevox_core/Cargo.toml b/crates/voicevox_core/Cargo.toml index bee2f822c..ecaa495e7 100644 --- a/crates/voicevox_core/Cargo.toml +++ b/crates/voicevox_core/Cargo.toml @@ -17,6 +17,7 @@ derive-new = "0.5.9" derive_more.workspace = true duplicate = "1.0.0" easy-ext.workspace = true +educe = "0.4.23" fs-err.workspace = true futures.workspace = true indexmap = { version = "2.0.0", features = ["serde"] } diff --git a/crates/voicevox_core/src/engine/synthesis_engine.rs b/crates/voicevox_core/src/engine/synthesis_engine.rs index 
22ced6f84..8db50b604 100644 --- a/crates/voicevox_core/src/engine/synthesis_engine.rs +++ b/crates/voicevox_core/src/engine/synthesis_engine.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use super::full_context_label::Utterance; use super::open_jtalk::OpenJtalk; use super::*; +use crate::infer::{InferenceRuntime, Output}; use crate::numerics::F32Ext as _; use crate::InferenceCore; @@ -15,18 +16,19 @@ const MORA_PHONEME_LIST: &[&str] = &[ ]; #[derive(new)] -pub struct SynthesisEngine { - inference_core: InferenceCore, +pub(crate) struct SynthesisEngine { + inference_core: InferenceCore, open_jtalk: Arc, } -#[allow(unsafe_code)] -unsafe impl Send for SynthesisEngine {} - -impl SynthesisEngine { +impl SynthesisEngine +where + R: InferenceRuntime, + (Vec,): Output, +{ pub const DEFAULT_SAMPLING_RATE: u32 = 24000; - pub fn inference_core(&self) -> &InferenceCore { + pub fn inference_core(&self) -> &InferenceCore { &self.inference_core } @@ -123,7 +125,7 @@ impl SynthesisEngine { accent_phrases: &[AccentPhraseModel], style_id: StyleId, ) -> Result> { - let (_, phoneme_data_list) = SynthesisEngine::initial_process(accent_phrases); + let (_, phoneme_data_list) = Self::initial_process(accent_phrases); let (_, _, vowel_indexes_data) = split_mora(&phoneme_data_list); @@ -185,7 +187,7 @@ impl SynthesisEngine { accent_phrases: &[AccentPhraseModel], style_id: StyleId, ) -> Result> { - let (_, phoneme_data_list) = SynthesisEngine::initial_process(accent_phrases); + let (_, phoneme_data_list) = Self::initial_process(accent_phrases); let mut base_start_accent_list = vec![0]; let mut base_end_accent_list = vec![0]; @@ -193,28 +195,12 @@ impl SynthesisEngine { let mut base_end_accent_phrase_list = vec![0]; for accent_phrase in accent_phrases { let mut accent = usize::from(*accent_phrase.accent() != 1); - SynthesisEngine::create_one_accent_list( - &mut base_start_accent_list, - accent_phrase, - accent as i32, - ); + Self::create_one_accent_list(&mut base_start_accent_list, accent_phrase, 
accent as i32); accent = *accent_phrase.accent() - 1; - SynthesisEngine::create_one_accent_list( - &mut base_end_accent_list, - accent_phrase, - accent as i32, - ); - SynthesisEngine::create_one_accent_list( - &mut base_start_accent_phrase_list, - accent_phrase, - 0, - ); - SynthesisEngine::create_one_accent_list( - &mut base_end_accent_phrase_list, - accent_phrase, - -1, - ); + Self::create_one_accent_list(&mut base_end_accent_list, accent_phrase, accent as i32); + Self::create_one_accent_list(&mut base_start_accent_phrase_list, accent_phrase, 0); + Self::create_one_accent_list(&mut base_end_accent_phrase_list, accent_phrase, -1); } base_start_accent_list.push(0); base_end_accent_list.push(0); @@ -328,7 +314,7 @@ impl SynthesisEngine { query.accent_phrases().clone() }; - let (flatten_moras, phoneme_data_list) = SynthesisEngine::initial_process(&accent_phrases); + let (flatten_moras, phoneme_data_list) = Self::initial_process(&accent_phrases); let mut phoneme_length_list = vec![pre_phoneme_length]; let mut f0_list = vec![0.]; @@ -647,12 +633,12 @@ mod tests { use ::test_util::OPEN_JTALK_DIC_DIR; use pretty_assertions::assert_eq; - use crate::*; + use crate::{infer::runtimes::Onnxruntime, *}; #[rstest] #[tokio::test] async fn is_openjtalk_dict_loaded_works() { - let core = InferenceCore::new(false, 0).unwrap(); + let core = InferenceCore::::new(false, 0).unwrap(); let synthesis_engine = SynthesisEngine::new(core, OpenJtalk::new(OPEN_JTALK_DIC_DIR).unwrap().into()); @@ -662,7 +648,7 @@ mod tests { #[rstest] #[tokio::test] async fn create_accent_phrases_works() { - let core = InferenceCore::new(false, 0).unwrap(); + let core = InferenceCore::::new(false, 0).unwrap(); let model = &VoiceModel::sample().await.unwrap(); core.load_model(model).await.unwrap(); diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index efc66b1e8..07a9f4b55 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -1,18 +1,18 @@ 
pub(crate) mod runtimes; pub(crate) mod signatures; -use std::{fmt::Debug, marker::PhantomData, sync::Arc}; +use std::{fmt::Debug, hash::Hash, marker::PhantomData, sync::Arc}; use derive_new::new; use ndarray::{Array, Dimension, LinalgScalar}; use thiserror::Error; -pub(crate) trait InferenceRuntime: Copy { +pub(crate) trait InferenceRuntime: Copy + Ord + Hash + Debug + 'static { type Session: Session; type RunBuilder<'a>: RunBuilder<'a, Runtime = Self>; } -pub(crate) trait Session: Sized + 'static { +pub(crate) trait Session: Sized + Send + 'static { fn new( model: impl FnOnce() -> std::result::Result, DecryptModelError>, options: SessionOptions, diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs index a26abbb74..9c33f67fc 100644 --- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -8,7 +8,7 @@ use crate::infer::{ pub(crate) use self::assert_send::AssertSend; -#[derive(Clone, Copy)] +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] pub(crate) enum Onnxruntime {} impl InferenceRuntime for Onnxruntime { diff --git a/crates/voicevox_core/src/inference_core.rs b/crates/voicevox_core/src/inference_core.rs index d1c7831c5..413c339cc 100644 --- a/crates/voicevox_core/src/inference_core.rs +++ b/crates/voicevox_core/src/inference_core.rs @@ -1,14 +1,21 @@ use self::status::*; use super::*; -use crate::infer::signatures::{Decode, PredictDuration, PredictIntonation}; +use crate::infer::{ + signatures::{Decode, PredictDuration, PredictIntonation}, + InferenceRuntime, Output, +}; const PHONEME_LENGTH_MINIMAL: f32 = 0.01; -pub struct InferenceCore { - status: Status, +pub(crate) struct InferenceCore { + status: Status, } -impl InferenceCore { +impl InferenceCore +where + R: InferenceRuntime, + (Vec,): Output, +{ pub(crate) fn new(use_gpu: bool, cpu_num_threads: u16) -> Result { if !use_gpu || 
Self::can_support_gpu_feature()? { let status = Status::new(use_gpu, cpu_num_threads); diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index 46e462d1a..1e4c7305b 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -1,10 +1,10 @@ use super::*; use crate::infer::{ - runtimes::Onnxruntime, signatures::{Decode, PredictDuration, PredictIntonation, SessionSet}, - DecryptModelError, Output, SessionOptions, Signature, TypedSession, + DecryptModelError, InferenceRuntime, Output, SessionOptions, Signature, TypedSession, }; use derive_more::Index; +use educe::Educe; use itertools::iproduct; use std::path::Path; use std::sync::Arc; @@ -13,13 +13,13 @@ mod model_file; use std::collections::BTreeMap; -pub struct Status { - loaded_models: std::sync::Mutex, +pub(crate) struct Status { + loaded_models: std::sync::Mutex>, light_session_options: SessionOptions, // 軽いモデルはこちらを使う heavy_session_options: SessionOptions, // 重いモデルはこちらを使う } -impl Status { +impl Status { pub fn new(use_gpu: bool, cpu_num_threads: u16) -> Self { Self { loaded_models: Default::default(), @@ -89,13 +89,14 @@ impl Status { model: &[u8], session_options: &SessionOptions, path: impl AsRef, - ) -> LoadModelResult> { - TypedSession::::new(|| model_file::decrypt(model), *session_options) - .map_err(|source| LoadModelError { + ) -> LoadModelResult> { + TypedSession::::new(|| model_file::decrypt(model), *session_options).map_err( + |source| LoadModelError { path: path.as_ref().to_owned(), context: LoadModelErrorKind::InvalidModelData, source: Some(source), - }) + }, + ) } pub fn validate_speaker_id(&self, style_id: StyleId) -> bool { @@ -112,13 +113,12 @@ impl Status { ) -> Result where S: Signature, - for<'a> &'a S::SessionSet: From<&'a SessionSet>, - S::Output: Output, + for<'a> &'a S::SessionSet: From<&'a SessionSet>, + S::Output: Output, { - let sess = S::get_session::( - (&self.loaded_models.lock().unwrap()[model_id].session_set).into(), - 
) - .clone(); + let sess = + S::get_session::((&self.loaded_models.lock().unwrap()[model_id].session_set).into()) + .clone(); tokio::task::spawn_blocking(move || { let mut sess = sess.lock().unwrap(); @@ -133,16 +133,17 @@ impl Status { /// 読み込んだモデルの`Session`とそのメタ情報を保有し、追加/削除/取得の操作を提供する。 /// /// この構造体のメソッドは、すべて一瞬で完了すべきである。 -#[derive(Default, Index)] -struct LoadedModels(BTreeMap); +#[derive(Educe, Index)] +#[educe(Default(bound = "R: InferenceRuntime"))] +struct LoadedModels(BTreeMap>); -struct LoadedModel { +struct LoadedModel { model_inner_ids: BTreeMap, metas: VoiceModelMeta, - session_set: SessionSet, + session_set: SessionSet, } -impl LoadedModels { +impl LoadedModels { fn metas(&self) -> VoiceModelMeta { self.0 .values() @@ -216,9 +217,9 @@ impl LoadedModels { fn insert( &mut self, model: &VoiceModel, - predict_duration: TypedSession, - predict_intonation: TypedSession, - decode: TypedSession, + predict_duration: TypedSession, + predict_intonation: TypedSession, + decode: TypedSession, ) -> Result<()> { self.ensure_acceptable(model)?; @@ -260,6 +261,7 @@ impl LoadedModels { mod tests { use super::*; + use crate::infer::runtimes::Onnxruntime; use crate::macros::tests::assert_debug_fmt_eq; use pretty_assertions::assert_eq; @@ -272,7 +274,7 @@ mod tests { #[case(false, 8)] #[case(false, 0)] fn status_new_works(#[case] use_gpu: bool, #[case] cpu_num_threads: u16) { - let status = Status::new(use_gpu, cpu_num_threads); + let status = Status::::new(use_gpu, cpu_num_threads); assert_eq!(false, status.light_session_options.use_gpu); assert_eq!(use_gpu, status.heavy_session_options.use_gpu); assert_eq!( @@ -289,7 +291,7 @@ mod tests { #[rstest] #[tokio::test] async fn status_load_model_works() { - let status = Status::new(false, 0); + let status = Status::::new(false, 0); let result = status.load_model(&open_default_vvm_file().await).await; assert_debug_fmt_eq!(Ok(()), result); assert_eq!(1, status.loaded_models.lock().unwrap().0.len()); @@ -298,7 +300,7 @@ mod tests 
{ #[rstest] #[tokio::test] async fn status_is_model_loaded_works() { - let status = Status::new(false, 0); + let status = Status::::new(false, 0); let vvm = open_default_vvm_file().await; assert!( !status.is_loaded_model(vvm.id()), diff --git a/crates/voicevox_core/src/synthesizer.rs b/crates/voicevox_core/src/synthesizer.rs index 98c3a5f82..178a724c2 100644 --- a/crates/voicevox_core/src/synthesizer.rs +++ b/crates/voicevox_core/src/synthesizer.rs @@ -1,6 +1,9 @@ use std::sync::Arc; -use crate::engine::{create_kana, parse_kana, AccentPhraseModel, OpenJtalk, SynthesisEngine}; +use crate::{ + engine::{create_kana, parse_kana, AccentPhraseModel, OpenJtalk, SynthesisEngine}, + infer::runtimes::Onnxruntime, +}; use super::*; @@ -67,9 +70,11 @@ pub struct InitializeOptions { pub cpu_num_threads: u16, } +type SynthesizerInferenceRuntime = Onnxruntime; + /// 音声シンセサイザ。 pub struct Synthesizer { - synthesis_engine: SynthesisEngine, + synthesis_engine: SynthesisEngine, use_gpu: bool, } From 4bd228149dd748c25e763669cded4796dd5a9d35 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Mon, 6 Nov 2023 04:26:15 +0900 Subject: [PATCH 03/47] =?UTF-8?q?`SupportedDevices::create`=E3=81=AE?= =?UTF-8?q?=E5=AE=9F=E8=A3=85=E3=82=92=E7=A7=BB=E5=8B=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/devices.rs | 26 ++++------------ crates/voicevox_core/src/infer.rs | 3 ++ .../src/infer/runtimes/onnxruntime.rs | 30 +++++++++++++++++-- crates/voicevox_core/src/synthesizer.rs | 4 +-- 4 files changed, 38 insertions(+), 25 deletions(-) diff --git a/crates/voicevox_core/src/devices.rs b/crates/voicevox_core/src/devices.rs index 70847cb81..545b5e485 100644 --- a/crates/voicevox_core/src/devices.rs +++ b/crates/voicevox_core/src/devices.rs @@ -1,6 +1,7 @@ use serde::{Deserialize, Serialize}; use super::*; +use crate::{infer::InferenceRuntime, synthesizer::InferenceRuntimeImpl}; /// このライブラリで利用可能なデバイスの情報。 /// @@ -11,21 +12,21 @@ 
pub struct SupportedDevices { /// CPUが利用可能。 /// /// 常に`true`。 - cpu: bool, + pub cpu: bool, /// CUDAが利用可能。 /// /// ONNX Runtimeの[CUDA Execution Provider] (`CUDAExecutionProvider`)に対応する。必要な環境につ /// いてはそちらを参照。 /// /// [CUDA Execution Provider]: https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html - cuda: bool, + pub cuda: bool, /// DirectMLが利用可能。 /// /// ONNX Runtimeの[DirectML Execution Provider] (`DmlExecutionProvider`)に対応する。必要な環境に /// ついてはそちらを参照。 /// /// [DirectML Execution Provider]: https://onnxruntime.ai/docs/execution-providers/DirectML-ExecutionProvider.html - dml: bool, + pub dml: bool, } impl SupportedDevices { @@ -42,24 +43,7 @@ impl SupportedDevices { /// # Result::<_, anyhow::Error>::Ok(()) /// ``` pub fn create() -> Result { - let mut cuda_support = false; - let mut dml_support = false; - for provider in onnxruntime::session::get_available_providers() - .map_err(ErrorRepr::GetSupportedDevices)? - .iter() - { - match provider.as_str() { - "CUDAExecutionProvider" => cuda_support = true, - "DmlExecutionProvider" => dml_support = true, - _ => {} - } - } - - Ok(SupportedDevices { - cpu: true, - cuda: cuda_support, - dml: dml_support, - }) + ::supported_devices() } pub fn to_json(&self) -> serde_json::Value { diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 07a9f4b55..92bfc9d5e 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -7,9 +7,12 @@ use derive_new::new; use ndarray::{Array, Dimension, LinalgScalar}; use thiserror::Error; +use crate::SupportedDevices; + pub(crate) trait InferenceRuntime: Copy + Ord + Hash + Debug + 'static { type Session: Session; type RunBuilder<'a>: RunBuilder<'a, Runtime = Self>; + fn supported_devices() -> crate::Result; } pub(crate) trait Session: Sized + Send + 'static { diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs index 9c33f67fc..6efed2fa1 100644 
--- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -2,8 +2,13 @@ use ndarray::{Array, Dimension}; use once_cell::sync::Lazy; use onnxruntime::{environment::Environment, GraphOptimizationLevel, LoggingLevel}; -use crate::infer::{ - DecryptModelError, InferenceRuntime, InputScalar, Output, RunBuilder, Session, SessionOptions, +use crate::{ + devices::SupportedDevices, + error::ErrorRepr, + infer::{ + DecryptModelError, InferenceRuntime, InputScalar, Output, RunBuilder, Session, + SessionOptions, + }, }; pub(crate) use self::assert_send::AssertSend; @@ -14,6 +19,27 @@ pub(crate) enum Onnxruntime {} impl InferenceRuntime for Onnxruntime { type Session = AssertSend>; type RunBuilder<'a> = OnnxruntimeInferenceBuilder<'a>; + + fn supported_devices() -> crate::Result { + let mut cuda_support = false; + let mut dml_support = false; + for provider in onnxruntime::session::get_available_providers() + .map_err(ErrorRepr::GetSupportedDevices)? 
+ .iter() + { + match provider.as_str() { + "CUDAExecutionProvider" => cuda_support = true, + "DmlExecutionProvider" => dml_support = true, + _ => {} + } + } + + Ok(SupportedDevices { + cpu: true, + cuda: cuda_support, + dml: dml_support, + }) + } } impl Session for AssertSend> { diff --git a/crates/voicevox_core/src/synthesizer.rs b/crates/voicevox_core/src/synthesizer.rs index 178a724c2..88f419476 100644 --- a/crates/voicevox_core/src/synthesizer.rs +++ b/crates/voicevox_core/src/synthesizer.rs @@ -70,11 +70,11 @@ pub struct InitializeOptions { pub cpu_num_threads: u16, } -type SynthesizerInferenceRuntime = Onnxruntime; +pub(crate) type InferenceRuntimeImpl = Onnxruntime; /// 音声シンセサイザ。 pub struct Synthesizer { - synthesis_engine: SynthesisEngine, + synthesis_engine: SynthesisEngine, use_gpu: bool, } From f9599112c6940be49aaa4f837ea863a5cead47ff Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Mon, 6 Nov 2023 04:47:38 +0900 Subject: [PATCH 04/47] =?UTF-8?q?=E4=B8=8D=E8=A6=81=E3=81=AAreexport?= =?UTF-8?q?=E3=82=92=E5=89=8A=E9=99=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer/runtimes/onnxruntime.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs index 6efed2fa1..b6e12dfcb 100644 --- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -2,6 +2,7 @@ use ndarray::{Array, Dimension}; use once_cell::sync::Lazy; use onnxruntime::{environment::Environment, GraphOptimizationLevel, LoggingLevel}; +use self::assert_send::AssertSend; use crate::{ devices::SupportedDevices, error::ErrorRepr, @@ -11,8 +12,6 @@ use crate::{ }, }; -pub(crate) use self::assert_send::AssertSend; - #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] pub(crate) enum Onnxruntime {} From 
55b04d3578ce4b7f245bd6a182e289ce1771dcdd Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Mon, 6 Nov 2023 05:00:58 +0900 Subject: [PATCH 05/47] =?UTF-8?q?`InferenceModels`=E3=81=AE=E5=AE=9A?= =?UTF-8?q?=E7=BE=A9=E3=82=92`signatures`=E3=81=AB=E7=A7=BB=E5=8B=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer/signatures.rs | 6 ++++++ crates/voicevox_core/src/status.rs | 11 ++++------- crates/voicevox_core/src/voice_model.rs | 18 ++++++------------ 3 files changed, 16 insertions(+), 19 deletions(-) diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/signatures.rs index 764d70b8d..ba8f4f580 100644 --- a/crates/voicevox_core/src/infer/signatures.rs +++ b/crates/voicevox_core/src/infer/signatures.rs @@ -4,6 +4,12 @@ use ndarray::{Array0, Array1, Array2}; use crate::infer::{InferenceRuntime, RunBuilder, Signature, TypedSession}; +pub(crate) struct ModelBytesSet { + pub(crate) predict_duration: Vec, + pub(crate) predict_intonation: Vec, + pub(crate) decode: Vec, +} + pub(crate) struct SessionSet { pub(crate) predict_duration: Arc>>, pub(crate) predict_intonation: Arc>>, diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index 1e4c7305b..de51bf459 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -37,20 +37,17 @@ impl Status { let models = model.read_inference_models().await?; let predict_duration_session = self.new_session( - models.predict_duration_model(), + &models.predict_duration, &self.light_session_options, model.path(), )?; let predict_intonation_session = self.new_session( - models.predict_intonation_model(), + &models.predict_intonation, &self.light_session_options, model.path(), )?; - let decode_model = self.new_session( - models.decode_model(), - &self.heavy_session_options, - model.path(), - )?; + let decode_model = + self.new_session(&models.decode, 
&self.heavy_session_options, model.path())?; self.loaded_models.lock().unwrap().insert( model, diff --git a/crates/voicevox_core/src/voice_model.rs b/crates/voicevox_core/src/voice_model.rs index 45e7dad17..3b94ee8fd 100644 --- a/crates/voicevox_core/src/voice_model.rs +++ b/crates/voicevox_core/src/voice_model.rs @@ -3,6 +3,7 @@ use futures::future::join3; use serde::{de::DeserializeOwned, Deserialize}; use super::*; +use crate::infer::signatures::ModelBytesSet; use std::{ collections::{BTreeMap, HashMap}, io, @@ -35,15 +36,8 @@ pub struct VoiceModel { path: PathBuf, } -#[derive(Getters)] -pub(crate) struct InferenceModels { - decode_model: Vec, - predict_duration_model: Vec, - predict_intonation_model: Vec, -} - impl VoiceModel { - pub(crate) async fn read_inference_models(&self) -> LoadModelResult { + pub(crate) async fn read_inference_models(&self) -> LoadModelResult { let reader = VvmEntryReader::open(&self.path).await?; let (decode_model_result, predict_duration_model_result, predict_intonation_model_result) = join3( @@ -53,10 +47,10 @@ impl VoiceModel { ) .await; - Ok(InferenceModels { - predict_duration_model: predict_duration_model_result?, - predict_intonation_model: predict_intonation_model_result?, - decode_model: decode_model_result?, + Ok(ModelBytesSet { + predict_duration: predict_duration_model_result?, + predict_intonation: predict_intonation_model_result?, + decode: decode_model_result?, }) } /// VVMファイルから`VoiceModel`をコンストラクトする。 From c38ad991af95ce2b26ca5f0725852f91c09f6db2 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Mon, 6 Nov 2023 22:10:43 +0900 Subject: [PATCH 06/47] =?UTF-8?q?`ErrorRepr::GetSupportedDevices`=E3=81=AE?= =?UTF-8?q?=E4=B8=AD=E8=BA=AB=E3=82=92`anyhow::Error`=E3=81=AB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/error.rs | 3 +-- crates/voicevox_core/src/infer/runtimes/onnxruntime.rs | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git 
a/crates/voicevox_core/src/error.rs b/crates/voicevox_core/src/error.rs index 44451ece5..043b51991 100644 --- a/crates/voicevox_core/src/error.rs +++ b/crates/voicevox_core/src/error.rs @@ -2,7 +2,6 @@ use self::engine::{FullContextLabelError, KanaParseError}; use super::*; //use engine:: use duplicate::duplicate_item; -use onnxruntime::OrtError; use std::path::PathBuf; use thiserror::Error; use uuid::Uuid; @@ -65,7 +64,7 @@ pub(crate) enum ErrorRepr { LoadModel(#[from] LoadModelError), #[error("サポートされているデバイス情報取得中にエラーが発生しました")] - GetSupportedDevices(#[source] OrtError), + GetSupportedDevices(#[source] anyhow::Error), #[error( "`{style_id}`に対するスタイルが見つかりませんでした。音声モデルが読み込まれていないか、読\ diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs index b6e12dfcb..636efd91a 100644 --- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -23,6 +23,7 @@ impl InferenceRuntime for Onnxruntime { let mut cuda_support = false; let mut dml_support = false; for provider in onnxruntime::session::get_available_providers() + .map_err(Into::into) .map_err(ErrorRepr::GetSupportedDevices)? 
.iter() { From 192417f99771a4117a96c4597194fee89a4a6f7b Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Tue, 7 Nov 2023 02:13:19 +0900 Subject: [PATCH 07/47] =?UTF-8?q?enum-map=20v3.0.0-beta.1=E3=82=92?= =?UTF-8?q?=E5=B0=8E=E5=85=A5=E3=81=97=E3=80=81`EnumMap`=E9=A7=86=E5=8B=95?= =?UTF-8?q?=E3=81=AB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 21 +++++ crates/voicevox_core/Cargo.toml | 1 + crates/voicevox_core/src/infer.rs | 62 ++++++++---- .../src/{status => infer}/model_file.rs | 0 crates/voicevox_core/src/infer/signatures.rs | 44 +++------ crates/voicevox_core/src/status.rs | 94 ++++++------------- crates/voicevox_core/src/voice_model.rs | 17 ++-- 7 files changed, 117 insertions(+), 122 deletions(-) rename crates/voicevox_core/src/{status => infer}/model_file.rs (100%) diff --git a/Cargo.lock b/Cargo.lock index eb7c9cb65..1ab3e37f5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1238,6 +1238,26 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "enum-map" +version = "3.0.0-beta.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e698c4fb1d30d2aeaf3b169ca72fbc019a049d7c85acc7f91d5f58a22e3ee13" +dependencies = [ + "enum-map-derive", +] + +[[package]] +name = "enum-map-derive" +version = "1.0.0-0.gat.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c69b3965971f5d0ea6a6dd26b55cdd517ae0e1425dc8d94e482a5915bd7ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.38", +] + [[package]] name = "enum-ordinalize" version = "3.1.15" @@ -4299,6 +4319,7 @@ dependencies = [ "duplicate", "easy-ext", "educe", + "enum-map", "fs-err", "futures", "heck", diff --git a/crates/voicevox_core/Cargo.toml b/crates/voicevox_core/Cargo.toml index ecaa495e7..cb8d88d3c 100644 --- a/crates/voicevox_core/Cargo.toml +++ b/crates/voicevox_core/Cargo.toml @@ -18,6 +18,7 @@ derive_more.workspace = true duplicate = "1.0.0" easy-ext.workspace = true educe = 
"0.4.23" +enum-map = "3.0.0-beta.1" fs-err.workspace = true futures.workspace = true indexmap = { version = "2.0.0", features = ["serde"] } diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 92bfc9d5e..0bc74b6a9 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -1,13 +1,15 @@ +mod model_file; pub(crate) mod runtimes; pub(crate) mod signatures; use std::{fmt::Debug, hash::Hash, marker::PhantomData, sync::Arc}; use derive_new::new; +use enum_map::{Enum, EnumMap}; use ndarray::{Array, Dimension, LinalgScalar}; use thiserror::Error; -use crate::SupportedDevices; +use crate::{ErrorRepr, SupportedDevices}; pub(crate) trait InferenceRuntime: Copy + Ord + Hash + Debug + 'static { type Session: Session; @@ -35,11 +37,9 @@ impl InputScalar for i64 {} impl InputScalar for f32 {} pub(crate) trait Signature: Sized + Send + Sync + 'static { - type SessionSet; + type Kind: Enum; type Output; - fn get_session( - session_set: &Self::SessionSet, - ) -> &Arc>>; + const KIND: Self::Kind; fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>); } @@ -47,30 +47,52 @@ pub(crate) trait Output: Sized + Send { fn run(ctx: R::RunBuilder<'_>) -> anyhow::Result; } -pub(crate) struct TypedSession { - inner: R::Session, - marker: PhantomData, -} +pub(crate) struct SessionSet( + EnumMap>>, +); -impl TypedSession { +impl SessionSet { pub(crate) fn new( - model: impl FnOnce() -> std::result::Result, DecryptModelError>, - options: SessionOptions, + model_bytes: &EnumMap>, + mut options: impl FnMut(K) -> SessionOptions, ) -> anyhow::Result { - let inner = R::Session::new(model, options)?; - Ok(Self { - inner, + let mut sessions = model_bytes + .iter() + .map(|(k, m)| { + let sess = R::Session::new(|| model_file::decrypt(m), options(k))?; + Ok(Some(Arc::new(std::sync::Mutex::new(sess)))) + }) + .collect::>>()?; + + Ok(Self(EnumMap::::from_fn(|k| { + sessions[k.into_usize()].take().expect("should exist") + }))) + } +} + 
+impl SessionSet { + pub(crate) fn get>(&self) -> SessionCell { + SessionCell { + inner: self.0[S::KIND].clone(), marker: PhantomData, - }) + } } +} + +pub(crate) struct SessionCell { + inner: Arc>, + marker: PhantomData, +} - pub(crate) fn run(&mut self, sig: S) -> anyhow::Result +impl SessionCell { + pub(crate) fn run(self, input: S) -> crate::Result where S::Output: Output, { - let mut ctx = R::RunBuilder::from(&mut self.inner); - sig.input(&mut ctx); - S::Output::run(ctx) + let mut inner = self.inner.lock().unwrap(); + let mut ctx = R::RunBuilder::from(&mut inner); + input.input(&mut ctx); + S::Output::run(ctx).map_err(|e| ErrorRepr::InferenceFailed(e).into()) } } diff --git a/crates/voicevox_core/src/status/model_file.rs b/crates/voicevox_core/src/infer/model_file.rs similarity index 100% rename from crates/voicevox_core/src/status/model_file.rs rename to crates/voicevox_core/src/infer/model_file.rs diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/signatures.rs index ba8f4f580..0998e8453 100644 --- a/crates/voicevox_core/src/infer/signatures.rs +++ b/crates/voicevox_core/src/infer/signatures.rs @@ -1,19 +1,13 @@ -use std::sync::Arc; - +use enum_map::Enum; use ndarray::{Array0, Array1, Array2}; -use crate::infer::{InferenceRuntime, RunBuilder, Signature, TypedSession}; - -pub(crate) struct ModelBytesSet { - pub(crate) predict_duration: Vec, - pub(crate) predict_intonation: Vec, - pub(crate) decode: Vec, -} +use crate::infer::{RunBuilder, Signature}; -pub(crate) struct SessionSet { - pub(crate) predict_duration: Arc>>, - pub(crate) predict_intonation: Arc>>, - pub(crate) decode: Arc>>, +#[derive(Clone, Copy, Enum)] +pub(crate) enum SignatureKind { + PredictDuration, + PredictIntonation, + Decode, } pub(crate) struct PredictDuration { @@ -22,14 +16,10 @@ pub(crate) struct PredictDuration { } impl Signature for PredictDuration { - type SessionSet = SessionSet; + type Kind = SignatureKind; type Output = (Vec,); - fn 
get_session( - session_set: &Self::SessionSet, - ) -> &Arc>> { - &session_set.predict_duration - } + const KIND: Self::Kind = SignatureKind::PredictDuration; fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>) { ctx.input(self.phoneme).input(self.speaker_id); @@ -48,14 +38,10 @@ pub(crate) struct PredictIntonation { } impl Signature for PredictIntonation { - type SessionSet = SessionSet; + type Kind = SignatureKind; type Output = (Vec,); - fn get_session( - session_set: &Self::SessionSet, - ) -> &Arc>> { - &session_set.predict_intonation - } + const KIND: Self::Kind = SignatureKind::PredictIntonation; fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>) { ctx.input(self.length) @@ -76,14 +62,10 @@ pub(crate) struct Decode { } impl Signature for Decode { - type SessionSet = SessionSet; + type Kind = SignatureKind; type Output = (Vec,); - fn get_session( - session_set: &Self::SessionSet, - ) -> &Arc>> { - &session_set.decode - } + const KIND: Self::Kind = SignatureKind::Decode; fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>) { ctx.input(self.f0) diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index de51bf459..96f5dcc0a 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -1,15 +1,10 @@ use super::*; use crate::infer::{ - signatures::{Decode, PredictDuration, PredictIntonation, SessionSet}, - DecryptModelError, InferenceRuntime, Output, SessionOptions, Signature, TypedSession, + signatures::SignatureKind, InferenceRuntime, Output, SessionOptions, SessionSet, Signature, }; use derive_more::Index; use educe::Educe; use itertools::iproduct; -use std::path::Path; -use std::sync::Arc; - -mod model_file; use std::collections::BTreeMap; @@ -34,27 +29,24 @@ impl Status { .unwrap() .ensure_acceptable(model)?; - let models = model.read_inference_models().await?; - - let predict_duration_session = self.new_session( - &models.predict_duration, - &self.light_session_options, - 
model.path(), - )?; - let predict_intonation_session = self.new_session( - &models.predict_intonation, - &self.light_session_options, - model.path(), - )?; - let decode_model = - self.new_session(&models.decode, &self.heavy_session_options, model.path())?; - - self.loaded_models.lock().unwrap().insert( - model, - predict_duration_session, - predict_intonation_session, - decode_model, - )?; + let model_bytes = &model.read_inference_models().await?; + + let session_set = SessionSet::new(model_bytes, |kind| match kind { + SignatureKind::PredictDuration | SignatureKind::PredictIntonation => { + self.light_session_options + } + SignatureKind::Decode => self.heavy_session_options, + }) + .map_err(|source| LoadModelError { + path: model.path().clone(), + context: LoadModelErrorKind::InvalidModelData, + source: Some(source), + })?; + + self.loaded_models + .lock() + .unwrap() + .insert(model, session_set)?; Ok(()) } @@ -81,21 +73,6 @@ impl Status { self.loaded_models.lock().unwrap().contains_style(style_id) } - fn new_session( - &self, - model: &[u8], - session_options: &SessionOptions, - path: impl AsRef, - ) -> LoadModelResult> { - TypedSession::::new(|| model_file::decrypt(model), *session_options).map_err( - |source| LoadModelError { - path: path.as_ref().to_owned(), - context: LoadModelErrorKind::InvalidModelData, - source: Some(source), - }, - ) - } - pub fn validate_speaker_id(&self, style_id: StyleId) -> bool { self.is_loaded_model_by_style_id(style_id) } @@ -109,21 +86,16 @@ impl Status { input: S, ) -> Result where - S: Signature, - for<'a> &'a S::SessionSet: From<&'a SessionSet>, + S: Signature, S::Output: Output, { - let sess = - S::get_session::((&self.loaded_models.lock().unwrap()[model_id].session_set).into()) - .clone(); - - tokio::task::spawn_blocking(move || { - let mut sess = sess.lock().unwrap(); - sess.run(input) - .map_err(|e| ErrorRepr::InferenceFailed(e).into()) - }) - .await - .unwrap() + let sess = self.loaded_models.lock().unwrap()[model_id] + 
.session_set + .get(); + + tokio::task::spawn_blocking(move || sess.run(input)) + .await + .unwrap() } } @@ -137,7 +109,7 @@ struct LoadedModels(BTreeMap>) struct LoadedModel { model_inner_ids: BTreeMap, metas: VoiceModelMeta, - session_set: SessionSet, + session_set: SessionSet, } impl LoadedModels { @@ -214,9 +186,7 @@ impl LoadedModels { fn insert( &mut self, model: &VoiceModel, - predict_duration: TypedSession, - predict_intonation: TypedSession, - decode: TypedSession, + session_set: SessionSet, ) -> Result<()> { self.ensure_acceptable(model)?; @@ -225,11 +195,7 @@ impl LoadedModels { LoadedModel { model_inner_ids: model.model_inner_ids(), metas: model.metas().clone(), - session_set: SessionSet { - predict_duration: Arc::new(std::sync::Mutex::new(predict_duration)), - predict_intonation: Arc::new(std::sync::Mutex::new(predict_intonation)), - decode: Arc::new(std::sync::Mutex::new(decode)), - }, + session_set, }, ); assert!(prev.is_none()); diff --git a/crates/voicevox_core/src/voice_model.rs b/crates/voicevox_core/src/voice_model.rs index 3b94ee8fd..a153b268f 100644 --- a/crates/voicevox_core/src/voice_model.rs +++ b/crates/voicevox_core/src/voice_model.rs @@ -1,9 +1,10 @@ use async_zip::{read::fs::ZipFileReader, ZipEntry}; +use enum_map::EnumMap; use futures::future::join3; use serde::{de::DeserializeOwned, Deserialize}; use super::*; -use crate::infer::signatures::ModelBytesSet; +use crate::infer::signatures::SignatureKind; use std::{ collections::{BTreeMap, HashMap}, io, @@ -37,7 +38,9 @@ pub struct VoiceModel { } impl VoiceModel { - pub(crate) async fn read_inference_models(&self) -> LoadModelResult { + pub(crate) async fn read_inference_models( + &self, + ) -> LoadModelResult>> { let reader = VvmEntryReader::open(&self.path).await?; let (decode_model_result, predict_duration_model_result, predict_intonation_model_result) = join3( @@ -47,11 +50,11 @@ impl VoiceModel { ) .await; - Ok(ModelBytesSet { - predict_duration: predict_duration_model_result?, - 
predict_intonation: predict_intonation_model_result?, - decode: decode_model_result?, - }) + Ok(EnumMap::from_array([ + predict_duration_model_result?, + predict_intonation_model_result?, + decode_model_result?, + ])) } /// VVMファイルから`VoiceModel`をコンストラクトする。 pub async fn from_path(path: impl AsRef) -> Result { From 20db67a881d43e550f3bd2f765c7262169ac0eec Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Tue, 7 Nov 2023 02:33:07 +0900 Subject: [PATCH 08/47] Minor refactor --- .../voicevox_core/src/engine/synthesis_engine.rs | 6 +++--- crates/voicevox_core/src/infer.rs | 14 ++++++-------- .../src/infer/runtimes/onnxruntime.rs | 2 +- crates/voicevox_core/src/status.rs | 8 ++++---- 4 files changed, 14 insertions(+), 16 deletions(-) diff --git a/crates/voicevox_core/src/engine/synthesis_engine.rs b/crates/voicevox_core/src/engine/synthesis_engine.rs index 8db50b604..fc91005c4 100644 --- a/crates/voicevox_core/src/engine/synthesis_engine.rs +++ b/crates/voicevox_core/src/engine/synthesis_engine.rs @@ -633,12 +633,12 @@ mod tests { use ::test_util::OPEN_JTALK_DIC_DIR; use pretty_assertions::assert_eq; - use crate::{infer::runtimes::Onnxruntime, *}; + use crate::{synthesizer::InferenceRuntimeImpl, *}; #[rstest] #[tokio::test] async fn is_openjtalk_dict_loaded_works() { - let core = InferenceCore::::new(false, 0).unwrap(); + let core = InferenceCore::::new(false, 0).unwrap(); let synthesis_engine = SynthesisEngine::new(core, OpenJtalk::new(OPEN_JTALK_DIC_DIR).unwrap().into()); @@ -648,7 +648,7 @@ mod tests { #[rstest] #[tokio::test] async fn create_accent_phrases_works() { - let core = InferenceCore::::new(false, 0).unwrap(); + let core = InferenceCore::::new(false, 0).unwrap(); let model = &VoiceModel::sample().await.unwrap(); core.load_model(model).await.unwrap(); diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 0bc74b6a9..d5a55cfea 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -2,7 +2,7 
@@ mod model_file; pub(crate) mod runtimes; pub(crate) mod signatures; -use std::{fmt::Debug, hash::Hash, marker::PhantomData, sync::Arc}; +use std::{fmt::Debug, marker::PhantomData, sync::Arc}; use derive_new::new; use enum_map::{Enum, EnumMap}; @@ -11,9 +11,9 @@ use thiserror::Error; use crate::{ErrorRepr, SupportedDevices}; -pub(crate) trait InferenceRuntime: Copy + Ord + Hash + Debug + 'static { +pub(crate) trait InferenceRuntime: 'static { type Session: Session; - type RunBuilder<'a>: RunBuilder<'a, Runtime = Self>; + type RunBuilder<'a>: RunBuilder<'a, Session = Self::Session>; fn supported_devices() -> crate::Result; } @@ -24,10 +24,8 @@ pub(crate) trait Session: Sized + Send + 'static { ) -> anyhow::Result; } -pub(crate) trait RunBuilder<'a>: - From<&'a mut ::Session> -{ - type Runtime: InferenceRuntime; +pub(crate) trait RunBuilder<'a>: From<&'a mut Self::Session> { + type Session: Session; fn input(&mut self, tensor: Array) -> &mut Self; } @@ -36,7 +34,7 @@ pub(crate) trait InputScalar: LinalgScalar + Debug + sealed::OnnxruntimeInputSca impl InputScalar for i64 {} impl InputScalar for f32 {} -pub(crate) trait Signature: Sized + Send + Sync + 'static { +pub(crate) trait Signature: Sized + Send + 'static { type Kind: Enum; type Output; const KIND: Self::Kind; diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs index 636efd91a..ebaa2bcda 100644 --- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -107,7 +107,7 @@ impl<'sess> From<&'sess mut AssertSend>> } impl<'sess> RunBuilder<'sess> for OnnxruntimeInferenceBuilder<'sess> { - type Runtime = Onnxruntime; + type Session = AssertSend>; fn input(&mut self, tensor: Array) -> &mut Self { self.inputs diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index 96f5dcc0a..a4cd4ee84 100644 --- a/crates/voicevox_core/src/status.rs +++ 
b/crates/voicevox_core/src/status.rs @@ -224,8 +224,8 @@ impl LoadedModels { mod tests { use super::*; - use crate::infer::runtimes::Onnxruntime; use crate::macros::tests::assert_debug_fmt_eq; + use crate::synthesizer::InferenceRuntimeImpl; use pretty_assertions::assert_eq; #[rstest] @@ -237,7 +237,7 @@ mod tests { #[case(false, 8)] #[case(false, 0)] fn status_new_works(#[case] use_gpu: bool, #[case] cpu_num_threads: u16) { - let status = Status::::new(use_gpu, cpu_num_threads); + let status = Status::::new(use_gpu, cpu_num_threads); assert_eq!(false, status.light_session_options.use_gpu); assert_eq!(use_gpu, status.heavy_session_options.use_gpu); assert_eq!( @@ -254,7 +254,7 @@ mod tests { #[rstest] #[tokio::test] async fn status_load_model_works() { - let status = Status::::new(false, 0); + let status = Status::::new(false, 0); let result = status.load_model(&open_default_vvm_file().await).await; assert_debug_fmt_eq!(Ok(()), result); assert_eq!(1, status.loaded_models.lock().unwrap().0.len()); @@ -263,7 +263,7 @@ mod tests { #[rstest] #[tokio::test] async fn status_is_model_loaded_works() { - let status = Status::::new(false, 0); + let status = Status::::new(false, 0); let vvm = open_default_vvm_file().await; assert!( !status.is_loaded_model(vvm.id()), From cc84068d34e455f87f3fe79eea7cbf5db6909c27 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Tue, 7 Nov 2023 05:10:46 +0900 Subject: [PATCH 09/47] Minor refactor --- crates/voicevox_core/src/infer.rs | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index d5a55cfea..1fccadf0a 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -2,7 +2,7 @@ mod model_file; pub(crate) mod runtimes; pub(crate) mod signatures; -use std::{fmt::Debug, marker::PhantomData, sync::Arc}; +use std::{collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc}; use derive_new::new; use 
enum_map::{Enum, EnumMap}; @@ -29,13 +29,13 @@ pub(crate) trait RunBuilder<'a>: From<&'a mut Self::Session> { fn input(&mut self, tensor: Array) -> &mut Self; } -pub(crate) trait InputScalar: LinalgScalar + Debug + sealed::OnnxruntimeInputScalar {} +pub(crate) trait InputScalar: LinalgScalar + Debug + sealed::InputScalar {} impl InputScalar for i64 {} impl InputScalar for f32 {} pub(crate) trait Signature: Sized + Send + 'static { - type Kind: Enum; + type Kind: Enum + Copy; type Output; const KIND: Self::Kind; fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>); @@ -49,7 +49,7 @@ pub(crate) struct SessionSet( EnumMap>>, ); -impl SessionSet { +impl SessionSet { pub(crate) fn new( model_bytes: &EnumMap>, mut options: impl FnMut(K) -> SessionOptions, @@ -58,12 +58,12 @@ impl SessionSet { .iter() .map(|(k, m)| { let sess = R::Session::new(|| model_file::decrypt(m), options(k))?; - Ok(Some(Arc::new(std::sync::Mutex::new(sess)))) + Ok((k.into_usize(), std::sync::Mutex::new(sess).into())) }) - .collect::>>()?; + .collect::>>()?; Ok(Self(EnumMap::::from_fn(|k| { - sessions[k.into_usize()].take().expect("should exist") + sessions.remove(&k.into_usize()).expect("should exist") }))) } } @@ -105,11 +105,15 @@ pub(crate) struct SessionOptions { pub(crate) struct DecryptModelError; mod sealed { + pub(crate) trait InputScalar: OnnxruntimeInputScalar {} + + impl InputScalar for i64 {} + impl InputScalar for f32 {} + pub(crate) trait OnnxruntimeInputScalar: onnxruntime::TypeToTensorElementDataType { } - impl OnnxruntimeInputScalar for i64 {} - impl OnnxruntimeInputScalar for f32 {} + impl OnnxruntimeInputScalar for T {} } From e0f29c649df39a4006a47c8813208ac9232cb9d1 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Thu, 9 Nov 2023 03:11:56 +0900 Subject: [PATCH 10/47] =?UTF-8?q?=E8=89=B2=E3=80=85=E5=86=8D=E6=A7=8B?= =?UTF-8?q?=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/engine/synthesis_engine.rs | 20 +-- 
crates/voicevox_core/src/infer.rs | 108 +++++++++------- .../src/infer/runtimes/onnxruntime.rs | 31 +++-- crates/voicevox_core/src/infer/signatures.rs | 117 +++++++++++++----- crates/voicevox_core/src/inference_core.rs | 22 ++-- crates/voicevox_core/src/status.rs | 30 +++-- crates/voicevox_core/src/synthesizer.rs | 7 +- 7 files changed, 209 insertions(+), 126 deletions(-) diff --git a/crates/voicevox_core/src/engine/synthesis_engine.rs b/crates/voicevox_core/src/engine/synthesis_engine.rs index fc91005c4..b171f978d 100644 --- a/crates/voicevox_core/src/engine/synthesis_engine.rs +++ b/crates/voicevox_core/src/engine/synthesis_engine.rs @@ -5,7 +5,10 @@ use std::sync::Arc; use super::full_context_label::Utterance; use super::open_jtalk::OpenJtalk; use super::*; -use crate::infer::{InferenceRuntime, Output}; +use crate::infer::{ + signatures::{Decode, PredictDuration, PredictIntonation}, + InferenceRuntime, SupportsInferenceSignature, +}; use crate::numerics::F32Ext as _; use crate::InferenceCore; @@ -15,19 +18,20 @@ const MORA_PHONEME_LIST: &[&str] = &[ "a", "i", "u", "e", "o", "N", "A", "I", "U", "E", "O", "cl", "pau", ]; +pub const DEFAULT_SAMPLING_RATE: u32 = 24000; + #[derive(new)] pub(crate) struct SynthesisEngine { inference_core: InferenceCore, open_jtalk: Arc, } -impl SynthesisEngine -where - R: InferenceRuntime, - (Vec,): Output, +impl< + R: SupportsInferenceSignature + + SupportsInferenceSignature + + SupportsInferenceSignature, + > SynthesisEngine { - pub const DEFAULT_SAMPLING_RATE: u32 = 24000; - pub fn inference_core(&self) -> &InferenceCore { &self.inference_core } @@ -426,7 +430,7 @@ where let num_channels: u16 = if output_stereo { 2 } else { 1 }; let bit_depth: u16 = 16; let repeat_count: u32 = - (output_sampling_rate / Self::DEFAULT_SAMPLING_RATE) * num_channels as u32; + (output_sampling_rate / DEFAULT_SAMPLING_RATE) * num_channels as u32; let block_size: u16 = bit_depth * num_channels / 8; let bytes_size = wave.len() as u32 * repeat_count * 2; 
diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 1fccadf0a..d10adb8bc 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -6,53 +6,70 @@ use std::{collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc}; use derive_new::new; use enum_map::{Enum, EnumMap}; -use ndarray::{Array, Dimension, LinalgScalar}; use thiserror::Error; use crate::{ErrorRepr, SupportedDevices}; pub(crate) trait InferenceRuntime: 'static { - type Session: Session; - type RunBuilder<'a>: RunBuilder<'a, Session = Self::Session>; + type Session: InferenceSession; + type RunContext<'a>: RunContext<'a, Session = Self::Session>; fn supported_devices() -> crate::Result; } -pub(crate) trait Session: Sized + Send + 'static { +pub(crate) trait InferenceSession: Sized + Send + 'static { fn new( model: impl FnOnce() -> std::result::Result, DecryptModelError>, - options: SessionOptions, + options: InferenceSessionOptions, ) -> anyhow::Result; } -pub(crate) trait RunBuilder<'a>: From<&'a mut Self::Session> { - type Session: Session; - fn input(&mut self, tensor: Array) -> &mut Self; +pub(crate) trait RunContext<'a>: From<&'a mut Self::Session> { + type Session: InferenceSession; } -pub(crate) trait InputScalar: LinalgScalar + Debug + sealed::InputScalar {} +pub(crate) trait SupportsInferenceSignature: + SupportsInferenceInputTensors + SupportsInferenceOutput +{ +} + +impl< + R: SupportsInferenceInputTensors + SupportsInferenceOutput, + S: InferenceSignature, + > SupportsInferenceSignature for R +{ +} + +pub(crate) trait SupportsInferenceInputTensor: InferenceRuntime { + fn input(ctx: &mut Self::RunContext<'_>, tensor: I); +} -impl InputScalar for i64 {} -impl InputScalar for f32 {} +pub(crate) trait SupportsInferenceInputTensors: InferenceRuntime { + fn input(ctx: &mut Self::RunContext<'_>, tensors: I); +} + +pub(crate) trait SupportsInferenceOutput: InferenceRuntime { + fn run(ctx: Self::RunContext<'_>) -> anyhow::Result; 
+} -pub(crate) trait Signature: Sized + Send + 'static { +pub(crate) trait InferenceSignature: Sized + Send + 'static { type Kind: Enum + Copy; - type Output; + type Input: InferenceInput; + type Output: Send; const KIND: Self::Kind; - fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>); } -pub(crate) trait Output: Sized + Send { - fn run(ctx: R::RunBuilder<'_>) -> anyhow::Result; +pub(crate) trait InferenceInput: Send + 'static { + type Signature: InferenceSignature; } -pub(crate) struct SessionSet( +pub(crate) struct InferenceSessionSet( EnumMap>>, ); -impl SessionSet { +impl InferenceSessionSet { pub(crate) fn new( model_bytes: &EnumMap>, - mut options: impl FnMut(K) -> SessionOptions, + mut options: impl FnMut(K) -> InferenceSessionOptions, ) -> anyhow::Result { let mut sessions = model_bytes .iter() @@ -68,34 +85,43 @@ impl SessionSet { } } -impl SessionSet { - pub(crate) fn get>(&self) -> SessionCell { - SessionCell { - inner: self.0[S::KIND].clone(), +impl InferenceSessionSet { + pub(crate) fn get(&self) -> InferenceSessionCell + where + I: InferenceInput, + I::Signature: InferenceSignature, + { + InferenceSessionCell { + inner: self.0[::KIND].clone(), marker: PhantomData, } } } -pub(crate) struct SessionCell { +pub(crate) struct InferenceSessionCell { inner: Arc>, - marker: PhantomData, + marker: PhantomData, } -impl SessionCell { - pub(crate) fn run(self, input: S) -> crate::Result - where - S::Output: Output, - { +impl< + R: SupportsInferenceInputTensors + + SupportsInferenceOutput<::Output>, + I: InferenceInput, + > InferenceSessionCell +{ + pub(crate) fn run( + self, + input: I, + ) -> crate::Result<::Output> { let mut inner = self.inner.lock().unwrap(); - let mut ctx = R::RunBuilder::from(&mut inner); - input.input(&mut ctx); - S::Output::run(ctx).map_err(|e| ErrorRepr::InferenceFailed(e).into()) + let mut ctx = R::RunContext::from(&mut inner); + R::input(&mut ctx, input); + R::run(ctx).map_err(|e| ErrorRepr::InferenceFailed(e).into()) } } 
#[derive(new, Clone, Copy)] -pub(crate) struct SessionOptions { +pub(crate) struct InferenceSessionOptions { pub(crate) cpu_num_threads: u16, pub(crate) use_gpu: bool, } @@ -103,17 +129,3 @@ pub(crate) struct SessionOptions { #[derive(Error, Debug)] #[error("不正なモデルファイルです")] pub(crate) struct DecryptModelError; - -mod sealed { - pub(crate) trait InputScalar: OnnxruntimeInputScalar {} - - impl InputScalar for i64 {} - impl InputScalar for f32 {} - - pub(crate) trait OnnxruntimeInputScalar: - onnxruntime::TypeToTensorElementDataType - { - } - - impl OnnxruntimeInputScalar for T {} -} diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs index ebaa2bcda..027047051 100644 --- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -1,14 +1,18 @@ +use std::fmt::Debug; + use ndarray::{Array, Dimension}; use once_cell::sync::Lazy; -use onnxruntime::{environment::Environment, GraphOptimizationLevel, LoggingLevel}; +use onnxruntime::{ + environment::Environment, GraphOptimizationLevel, LoggingLevel, TypeToTensorElementDataType, +}; use self::assert_send::AssertSend; use crate::{ devices::SupportedDevices, error::ErrorRepr, infer::{ - DecryptModelError, InferenceRuntime, InputScalar, Output, RunBuilder, Session, - SessionOptions, + DecryptModelError, InferenceRuntime, InferenceSession, InferenceSessionOptions, RunContext, + SupportsInferenceInputTensor, SupportsInferenceOutput, }, }; @@ -17,7 +21,7 @@ pub(crate) enum Onnxruntime {} impl InferenceRuntime for Onnxruntime { type Session = AssertSend>; - type RunBuilder<'a> = OnnxruntimeInferenceBuilder<'a>; + type RunContext<'a> = OnnxruntimeInferenceBuilder<'a>; fn supported_devices() -> crate::Result { let mut cuda_support = false; @@ -42,10 +46,10 @@ impl InferenceRuntime for Onnxruntime { } } -impl Session for AssertSend> { +impl InferenceSession for AssertSend> { fn new( model: impl 
FnOnce() -> std::result::Result, DecryptModelError>, - options: SessionOptions, + options: InferenceSessionOptions, ) -> anyhow::Result { let mut builder = ENVIRONMENT .new_session_builder()? @@ -106,20 +110,23 @@ impl<'sess> From<&'sess mut AssertSend>> } } -impl<'sess> RunBuilder<'sess> for OnnxruntimeInferenceBuilder<'sess> { +impl<'sess> RunContext<'sess> for OnnxruntimeInferenceBuilder<'sess> { type Session = AssertSend>; +} - fn input(&mut self, tensor: Array) -> &mut Self { - self.inputs +impl + SupportsInferenceInputTensor> for Onnxruntime +{ + fn input(ctx: &mut Self::RunContext<'_>, tensor: Array) { + ctx.inputs .push(Box::new(onnxruntime::session::NdArray::new(tensor))); - self } } -impl Output for (Vec,) { +impl SupportsInferenceOutput<(Vec,)> for Onnxruntime { fn run( OnnxruntimeInferenceBuilder { sess, mut inputs }: OnnxruntimeInferenceBuilder<'_>, - ) -> anyhow::Result { + ) -> anyhow::Result<(Vec,)> { let outputs = sess.run(inputs.iter_mut().map(|t| &mut **t as &mut _).collect())?; // FIXME: 2個以上の出力や二次元以上の出力をちゃんとしたやりかたで弾く diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/signatures.rs index 0998e8453..8f7f898dd 100644 --- a/crates/voicevox_core/src/infer/signatures.rs +++ b/crates/voicevox_core/src/infer/signatures.rs @@ -1,7 +1,25 @@ use enum_map::Enum; use ndarray::{Array0, Array1, Array2}; -use crate::infer::{RunBuilder, Signature}; +use crate::infer::{ + InferenceInput, InferenceSignature, SupportsInferenceInputTensor, + SupportsInferenceInputTensors, SupportsInferenceSignature, +}; + +pub(crate) trait SupportsAllSignatures: + SupportsInferenceSignature + + SupportsInferenceSignature + + SupportsInferenceSignature +{ +} + +impl< + R: SupportsInferenceSignature + + SupportsInferenceSignature + + SupportsInferenceSignature, + > SupportsAllSignatures for R +{ +} #[derive(Clone, Copy, Enum)] pub(crate) enum SignatureKind { @@ -10,23 +28,43 @@ pub(crate) enum SignatureKind { Decode, } -pub(crate) struct 
PredictDuration { - pub(crate) phoneme: Array1, - pub(crate) speaker_id: Array1, -} +pub(crate) enum PredictDuration {} -impl Signature for PredictDuration { +impl InferenceSignature for PredictDuration { type Kind = SignatureKind; + type Input = PredictDurationInput; type Output = (Vec,); - const KIND: Self::Kind = SignatureKind::PredictDuration; +} - fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>) { - ctx.input(self.phoneme).input(self.speaker_id); +pub(crate) struct PredictDurationInput { + pub(crate) phoneme: Array1, + pub(crate) speaker_id: Array1, +} + +impl InferenceInput for PredictDurationInput { + type Signature = PredictDuration; +} + +impl>> + SupportsInferenceInputTensors for R +{ + fn input(ctx: &mut R::RunContext<'_>, input: PredictDurationInput) { + R::input(ctx, input.phoneme); + R::input(ctx, input.speaker_id); } } -pub(crate) struct PredictIntonation { +pub(crate) enum PredictIntonation {} + +impl InferenceSignature for PredictIntonation { + type Kind = SignatureKind; + type Input = PredictIntonationInput; + type Output = (Vec,); + const KIND: Self::Kind = SignatureKind::PredictIntonation; +} + +pub(crate) struct PredictIntonationInput { pub(crate) length: Array0, pub(crate) vowel_phoneme: Array1, pub(crate) consonant_phoneme: Array1, @@ -37,39 +75,50 @@ pub(crate) struct PredictIntonation { pub(crate) speaker_id: Array1, } -impl Signature for PredictIntonation { - type Kind = SignatureKind; - type Output = (Vec,); - - const KIND: Self::Kind = SignatureKind::PredictIntonation; +impl InferenceInput for PredictIntonationInput { + type Signature = PredictIntonation; +} - fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>) { - ctx.input(self.length) - .input(self.vowel_phoneme) - .input(self.consonant_phoneme) - .input(self.start_accent) - .input(self.end_accent) - .input(self.start_accent_phrase) - .input(self.end_accent_phrase) - .input(self.speaker_id); +impl> + SupportsInferenceInputTensor>> + SupportsInferenceInputTensors for R +{ 
+ fn input(ctx: &mut R::RunContext<'_>, input: PredictIntonationInput) { + R::input(ctx, input.length); + R::input(ctx, input.vowel_phoneme); + R::input(ctx, input.consonant_phoneme); + R::input(ctx, input.start_accent); + R::input(ctx, input.end_accent); + R::input(ctx, input.start_accent_phrase); + R::input(ctx, input.end_accent_phrase); + R::input(ctx, input.speaker_id); } } -pub(crate) struct Decode { +pub(crate) enum Decode {} + +impl InferenceSignature for Decode { + type Kind = SignatureKind; + type Input = DecodeInput; + type Output = (Vec,); + const KIND: Self::Kind = SignatureKind::Decode; +} + +pub(crate) struct DecodeInput { pub(crate) f0: Array2, pub(crate) phoneme: Array2, pub(crate) speaker_id: Array1, } -impl Signature for Decode { - type Kind = SignatureKind; - type Output = (Vec,); - - const KIND: Self::Kind = SignatureKind::Decode; +impl InferenceInput for DecodeInput { + type Signature = Decode; +} - fn input<'a, 'b>(self, ctx: &'a mut impl RunBuilder<'b>) { - ctx.input(self.f0) - .input(self.phoneme) - .input(self.speaker_id); +impl> + SupportsInferenceInputTensor>> + SupportsInferenceInputTensors for R +{ + fn input(ctx: &mut R::RunContext<'_>, input: DecodeInput) { + R::input(ctx, input.f0); + R::input(ctx, input.phoneme); + R::input(ctx, input.speaker_id); } } diff --git a/crates/voicevox_core/src/inference_core.rs b/crates/voicevox_core/src/inference_core.rs index 413c339cc..264f56942 100644 --- a/crates/voicevox_core/src/inference_core.rs +++ b/crates/voicevox_core/src/inference_core.rs @@ -1,8 +1,11 @@ use self::status::*; use super::*; use crate::infer::{ - signatures::{Decode, PredictDuration, PredictIntonation}, - InferenceRuntime, Output, + signatures::{ + Decode, DecodeInput, PredictDuration, PredictDurationInput, PredictIntonation, + PredictIntonationInput, + }, + InferenceRuntime, SupportsInferenceSignature, }; const PHONEME_LENGTH_MINIMAL: f32 = 0.01; @@ -11,10 +14,11 @@ pub(crate) struct InferenceCore { status: Status, } -impl 
InferenceCore -where - R: InferenceRuntime, - (Vec,): Output, +impl< + R: SupportsInferenceSignature + + SupportsInferenceSignature + + SupportsInferenceSignature, + > InferenceCore { pub(crate) fn new(use_gpu: bool, cpu_num_threads: u16) -> Result { if !use_gpu || Self::can_support_gpu_feature()? { @@ -71,7 +75,7 @@ where .status .run_session( &model_id, - PredictDuration { + PredictDurationInput { phoneme: ndarray::arr1(phoneme_vector), speaker_id: ndarray::arr1(&[model_inner_id.raw_id().into()]), }, @@ -109,7 +113,7 @@ where .status .run_session( &model_id, - PredictIntonation { + PredictIntonationInput { length: ndarray::arr0(length as i64), vowel_phoneme: ndarray::arr1(vowel_phoneme_vector), consonant_phoneme: ndarray::arr1(consonant_phoneme_vector), @@ -159,7 +163,7 @@ where .status .run_session( &model_id, - Decode { + DecodeInput { f0: ndarray::arr1(&f0_with_padding) .into_shape([length_with_padding, 1]) .unwrap(), diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index a4cd4ee84..8ebc5e7b3 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -1,6 +1,8 @@ use super::*; use crate::infer::{ - signatures::SignatureKind, InferenceRuntime, Output, SessionOptions, SessionSet, Signature, + signatures::SignatureKind, InferenceInput, InferenceRuntime, InferenceSessionOptions, + InferenceSessionSet, InferenceSignature, SupportsInferenceInputTensors, + SupportsInferenceOutput, }; use derive_more::Index; use educe::Educe; @@ -10,16 +12,16 @@ use std::collections::BTreeMap; pub(crate) struct Status { loaded_models: std::sync::Mutex>, - light_session_options: SessionOptions, // 軽いモデルはこちらを使う - heavy_session_options: SessionOptions, // 重いモデルはこちらを使う + light_session_options: InferenceSessionOptions, // 軽いモデルはこちらを使う + heavy_session_options: InferenceSessionOptions, // 重いモデルはこちらを使う } impl Status { pub fn new(use_gpu: bool, cpu_num_threads: u16) -> Self { Self { loaded_models: Default::default(), - 
light_session_options: SessionOptions::new(cpu_num_threads, false), - heavy_session_options: SessionOptions::new(cpu_num_threads, use_gpu), + light_session_options: InferenceSessionOptions::new(cpu_num_threads, false), + heavy_session_options: InferenceSessionOptions::new(cpu_num_threads, use_gpu), } } @@ -31,7 +33,7 @@ impl Status { let model_bytes = &model.read_inference_models().await?; - let session_set = SessionSet::new(model_bytes, |kind| match kind { + let session_set = InferenceSessionSet::new(model_bytes, |kind| match kind { SignatureKind::PredictDuration | SignatureKind::PredictIntonation => { self.light_session_options } @@ -80,14 +82,16 @@ impl Status { /// # Panics /// /// `self`が`model_id`を含んでいないとき、パニックする。 - pub(crate) async fn run_session( + pub(crate) async fn run_session( &self, model_id: &VoiceModelId, - input: S, - ) -> Result + input: I, + ) -> Result<::Output> where - S: Signature, - S::Output: Output, + I: InferenceInput, + I::Signature: InferenceSignature, + R: SupportsInferenceInputTensors + + SupportsInferenceOutput<::Output>, { let sess = self.loaded_models.lock().unwrap()[model_id] .session_set @@ -109,7 +113,7 @@ struct LoadedModels(BTreeMap>) struct LoadedModel { model_inner_ids: BTreeMap, metas: VoiceModelMeta, - session_set: SessionSet, + session_set: InferenceSessionSet, } impl LoadedModels { @@ -186,7 +190,7 @@ impl LoadedModels { fn insert( &mut self, model: &VoiceModel, - session_set: SessionSet, + session_set: InferenceSessionSet, ) -> Result<()> { self.ensure_acceptable(model)?; diff --git a/crates/voicevox_core/src/synthesizer.rs b/crates/voicevox_core/src/synthesizer.rs index 88f419476..594cfd856 100644 --- a/crates/voicevox_core/src/synthesizer.rs +++ b/crates/voicevox_core/src/synthesizer.rs @@ -1,7 +1,10 @@ use std::sync::Arc; use crate::{ - engine::{create_kana, parse_kana, AccentPhraseModel, OpenJtalk, SynthesisEngine}, + engine::{ + create_kana, parse_kana, AccentPhraseModel, OpenJtalk, SynthesisEngine, + 
DEFAULT_SAMPLING_RATE, + }, infer::runtimes::Onnxruntime, }; @@ -560,7 +563,7 @@ impl AudioQueryModel { 1., 0.1, 0.1, - SynthesisEngine::DEFAULT_SAMPLING_RATE, + DEFAULT_SAMPLING_RATE, false, Some(kana), ) From cb1db348ab43b01b542c6404e6c0283e9aae251b Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Thu, 9 Nov 2023 03:44:16 +0900 Subject: [PATCH 11/47] Fix up --- crates/voicevox_core/src/infer.rs | 26 +++++++--- .../src/infer/runtimes/onnxruntime.rs | 4 +- crates/voicevox_core/src/infer/signatures.rs | 50 +++++++------------ 3 files changed, 39 insertions(+), 41 deletions(-) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index d10adb8bc..d12b15418 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -5,6 +5,7 @@ pub(crate) mod signatures; use std::{collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc}; use derive_new::new; +use easy_ext::ext; use enum_map::{Enum, EnumMap}; use thiserror::Error; @@ -12,7 +13,7 @@ use crate::{ErrorRepr, SupportedDevices}; pub(crate) trait InferenceRuntime: 'static { type Session: InferenceSession; - type RunContext<'a>: RunContext<'a, Session = Self::Session>; + type RunContext<'a>: RunContext<'a, Runtime = Self>; fn supported_devices() -> crate::Result; } @@ -23,8 +24,21 @@ pub(crate) trait InferenceSession: Sized + Send + 'static { ) -> anyhow::Result; } -pub(crate) trait RunContext<'a>: From<&'a mut Self::Session> { - type Session: InferenceSession; +pub(crate) trait RunContext<'a>: + From<&'a mut ::Session> +{ + type Runtime: InferenceRuntime = Self>; +} + +#[ext(RunContextExt)] +impl<'a, T: RunContext<'a>> T { + fn input(&mut self, tensor: I) -> &mut Self + where + T::Runtime: SupportsInferenceInputTensor, + { + >::input(tensor, self); + self + } } pub(crate) trait SupportsInferenceSignature: @@ -40,11 +54,11 @@ impl< } pub(crate) trait SupportsInferenceInputTensor: InferenceRuntime { - fn input(ctx: &mut Self::RunContext<'_>, tensor: I); + fn 
input(tensor: I, ctx: &mut Self::RunContext<'_>); } pub(crate) trait SupportsInferenceInputTensors: InferenceRuntime { - fn input(ctx: &mut Self::RunContext<'_>, tensors: I); + fn input(tensors: I, ctx: &mut Self::RunContext<'_>); } pub(crate) trait SupportsInferenceOutput: InferenceRuntime { @@ -115,7 +129,7 @@ impl< ) -> crate::Result<::Output> { let mut inner = self.inner.lock().unwrap(); let mut ctx = R::RunContext::from(&mut inner); - R::input(&mut ctx, input); + R::input(input, &mut ctx); R::run(ctx).map_err(|e| ErrorRepr::InferenceFailed(e).into()) } } diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs index 027047051..e75c3beb1 100644 --- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -111,13 +111,13 @@ impl<'sess> From<&'sess mut AssertSend>> } impl<'sess> RunContext<'sess> for OnnxruntimeInferenceBuilder<'sess> { - type Session = AssertSend>; + type Runtime = Onnxruntime; } impl SupportsInferenceInputTensor> for Onnxruntime { - fn input(ctx: &mut Self::RunContext<'_>, tensor: Array) { + fn input(tensor: Array, ctx: &mut Self::RunContext<'_>) { ctx.inputs .push(Box::new(onnxruntime::session::NdArray::new(tensor))); } diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/signatures.rs index 8f7f898dd..ef9ac47b9 100644 --- a/crates/voicevox_core/src/infer/signatures.rs +++ b/crates/voicevox_core/src/infer/signatures.rs @@ -2,25 +2,10 @@ use enum_map::Enum; use ndarray::{Array0, Array1, Array2}; use crate::infer::{ - InferenceInput, InferenceSignature, SupportsInferenceInputTensor, - SupportsInferenceInputTensors, SupportsInferenceSignature, + InferenceInput, InferenceSignature, RunContextExt as _, SupportsInferenceInputTensor, + SupportsInferenceInputTensors, }; -pub(crate) trait SupportsAllSignatures: - SupportsInferenceSignature - + SupportsInferenceSignature - + 
SupportsInferenceSignature -{ -} - -impl< - R: SupportsInferenceSignature - + SupportsInferenceSignature - + SupportsInferenceSignature, - > SupportsAllSignatures for R -{ -} - #[derive(Clone, Copy, Enum)] pub(crate) enum SignatureKind { PredictDuration, @@ -49,9 +34,8 @@ impl InferenceInput for PredictDurationInput { impl>> SupportsInferenceInputTensors for R { - fn input(ctx: &mut R::RunContext<'_>, input: PredictDurationInput) { - R::input(ctx, input.phoneme); - R::input(ctx, input.speaker_id); + fn input(input: PredictDurationInput, ctx: &mut R::RunContext<'_>) { + ctx.input(input.phoneme).input(input.speaker_id); } } @@ -82,15 +66,15 @@ impl InferenceInput for PredictIntonationInput { impl> + SupportsInferenceInputTensor>> SupportsInferenceInputTensors for R { - fn input(ctx: &mut R::RunContext<'_>, input: PredictIntonationInput) { - R::input(ctx, input.length); - R::input(ctx, input.vowel_phoneme); - R::input(ctx, input.consonant_phoneme); - R::input(ctx, input.start_accent); - R::input(ctx, input.end_accent); - R::input(ctx, input.start_accent_phrase); - R::input(ctx, input.end_accent_phrase); - R::input(ctx, input.speaker_id); + fn input(input: PredictIntonationInput, ctx: &mut R::RunContext<'_>) { + ctx.input(input.length) + .input(input.vowel_phoneme) + .input(input.consonant_phoneme) + .input(input.start_accent) + .input(input.end_accent) + .input(input.start_accent_phrase) + .input(input.end_accent_phrase) + .input(input.speaker_id); } } @@ -116,9 +100,9 @@ impl InferenceInput for DecodeInput { impl> + SupportsInferenceInputTensor>> SupportsInferenceInputTensors for R { - fn input(ctx: &mut R::RunContext<'_>, input: DecodeInput) { - R::input(ctx, input.f0); - R::input(ctx, input.phoneme); - R::input(ctx, input.speaker_id); + fn input(input: DecodeInput, ctx: &mut R::RunContext<'_>) { + ctx.input(input.f0) + .input(input.phoneme) + .input(input.speaker_id); } } From c3e08dd4da27b9691203f071c2bf716badf1e087 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita 
Date: Thu, 9 Nov 2023 03:56:22 +0900 Subject: [PATCH 12/47] =?UTF-8?q?`OnnxruntimeInferenceBuilder`=20=E2=86=92?= =?UTF-8?q?=20`OnnxruntimeRunContext`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer/runtimes/onnxruntime.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs index e75c3beb1..eb007fc19 100644 --- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -21,7 +21,7 @@ pub(crate) enum Onnxruntime {} impl InferenceRuntime for Onnxruntime { type Session = AssertSend>; - type RunContext<'a> = OnnxruntimeInferenceBuilder<'a>; + type RunContext<'a> = OnnxruntimeRunContext<'a>; fn supported_devices() -> crate::Result { let mut cuda_support = false; @@ -94,13 +94,13 @@ impl InferenceSession for AssertSend> { } } -pub(crate) struct OnnxruntimeInferenceBuilder<'sess> { +pub(crate) struct OnnxruntimeRunContext<'sess> { sess: &'sess mut AssertSend>, inputs: Vec>, } impl<'sess> From<&'sess mut AssertSend>> - for OnnxruntimeInferenceBuilder<'sess> + for OnnxruntimeRunContext<'sess> { fn from(sess: &'sess mut AssertSend>) -> Self { Self { @@ -110,7 +110,7 @@ impl<'sess> From<&'sess mut AssertSend>> } } -impl<'sess> RunContext<'sess> for OnnxruntimeInferenceBuilder<'sess> { +impl<'sess> RunContext<'sess> for OnnxruntimeRunContext<'sess> { type Runtime = Onnxruntime; } @@ -125,7 +125,7 @@ impl impl SupportsInferenceOutput<(Vec,)> for Onnxruntime { fn run( - OnnxruntimeInferenceBuilder { sess, mut inputs }: OnnxruntimeInferenceBuilder<'_>, + OnnxruntimeRunContext { sess, mut inputs }: OnnxruntimeRunContext<'_>, ) -> anyhow::Result<(Vec,)> { let outputs = sess.run(inputs.iter_mut().map(|t| &mut **t as &mut _).collect())?; From e4b91abbf4fa95711a9f62f72e5fcf00b0735999 Mon Sep 17 00:00:00 2001 
From: Ryo Yamashita Date: Thu, 9 Nov 2023 04:02:46 +0900 Subject: [PATCH 13/47] =?UTF-8?q?`impl=20SupportsInferenceOutput<=5F>=20fo?= =?UTF-8?q?r=20Onnxruntime`=E3=82=92=E6=8B=A1=E5=BC=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer/runtimes/onnxruntime.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs index eb007fc19..e0c54dd6e 100644 --- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -123,10 +123,12 @@ impl } } -impl SupportsInferenceOutput<(Vec,)> for Onnxruntime { +impl SupportsInferenceOutput<(Vec,)> + for Onnxruntime +{ fn run( OnnxruntimeRunContext { sess, mut inputs }: OnnxruntimeRunContext<'_>, - ) -> anyhow::Result<(Vec,)> { + ) -> anyhow::Result<(Vec,)> { let outputs = sess.run(inputs.iter_mut().map(|t| &mut **t as &mut _).collect())?; // FIXME: 2個以上の出力や二次元以上の出力をちゃんとしたやりかたで弾く From 8584d2727252c1bd6971f01150386ebd3adc73e7 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Thu, 9 Nov 2023 04:27:39 +0900 Subject: [PATCH 14/47] =?UTF-8?q?`SignatureKind`=20=E2=86=92=20`InferenceS?= =?UTF-8?q?ignatureKind`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer/signatures.rs | 14 +++++++------- crates/voicevox_core/src/status.rs | 12 ++++++------ crates/voicevox_core/src/voice_model.rs | 4 ++-- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/signatures.rs index ef9ac47b9..ed20dd23b 100644 --- a/crates/voicevox_core/src/infer/signatures.rs +++ b/crates/voicevox_core/src/infer/signatures.rs @@ -7,7 +7,7 @@ use crate::infer::{ }; #[derive(Clone, Copy, Enum)] -pub(crate) enum SignatureKind { +pub(crate) enum 
InferenceSignatureKind { PredictDuration, PredictIntonation, Decode, @@ -16,10 +16,10 @@ pub(crate) enum SignatureKind { pub(crate) enum PredictDuration {} impl InferenceSignature for PredictDuration { - type Kind = SignatureKind; + type Kind = InferenceSignatureKind; type Input = PredictDurationInput; type Output = (Vec,); - const KIND: Self::Kind = SignatureKind::PredictDuration; + const KIND: Self::Kind = InferenceSignatureKind::PredictDuration; } pub(crate) struct PredictDurationInput { @@ -42,10 +42,10 @@ impl>> pub(crate) enum PredictIntonation {} impl InferenceSignature for PredictIntonation { - type Kind = SignatureKind; + type Kind = InferenceSignatureKind; type Input = PredictIntonationInput; type Output = (Vec,); - const KIND: Self::Kind = SignatureKind::PredictIntonation; + const KIND: Self::Kind = InferenceSignatureKind::PredictIntonation; } pub(crate) struct PredictIntonationInput { @@ -81,10 +81,10 @@ impl> + SupportsInferenceInputTensor pub(crate) enum Decode {} impl InferenceSignature for Decode { - type Kind = SignatureKind; + type Kind = InferenceSignatureKind; type Input = DecodeInput; type Output = (Vec,); - const KIND: Self::Kind = SignatureKind::Decode; + const KIND: Self::Kind = InferenceSignatureKind::Decode; } pub(crate) struct DecodeInput { diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index 8ebc5e7b3..fa8d599a8 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -1,6 +1,6 @@ use super::*; use crate::infer::{ - signatures::SignatureKind, InferenceInput, InferenceRuntime, InferenceSessionOptions, + signatures::InferenceSignatureKind, InferenceInput, InferenceRuntime, InferenceSessionOptions, InferenceSessionSet, InferenceSignature, SupportsInferenceInputTensors, SupportsInferenceOutput, }; @@ -34,10 +34,10 @@ impl Status { let model_bytes = &model.read_inference_models().await?; let session_set = InferenceSessionSet::new(model_bytes, |kind| match kind { - 
SignatureKind::PredictDuration | SignatureKind::PredictIntonation => { + InferenceSignatureKind::PredictDuration | InferenceSignatureKind::PredictIntonation => { self.light_session_options } - SignatureKind::Decode => self.heavy_session_options, + InferenceSignatureKind::Decode => self.heavy_session_options, }) .map_err(|source| LoadModelError { path: model.path().clone(), @@ -89,7 +89,7 @@ impl Status { ) -> Result<::Output> where I: InferenceInput, - I::Signature: InferenceSignature, + I::Signature: InferenceSignature, R: SupportsInferenceInputTensors + SupportsInferenceOutput<::Output>, { @@ -113,7 +113,7 @@ struct LoadedModels(BTreeMap>) struct LoadedModel { model_inner_ids: BTreeMap, metas: VoiceModelMeta, - session_set: InferenceSessionSet, + session_set: InferenceSessionSet, } impl LoadedModels { @@ -190,7 +190,7 @@ impl LoadedModels { fn insert( &mut self, model: &VoiceModel, - session_set: InferenceSessionSet, + session_set: InferenceSessionSet, ) -> Result<()> { self.ensure_acceptable(model)?; diff --git a/crates/voicevox_core/src/voice_model.rs b/crates/voicevox_core/src/voice_model.rs index a153b268f..1be2b6a8b 100644 --- a/crates/voicevox_core/src/voice_model.rs +++ b/crates/voicevox_core/src/voice_model.rs @@ -4,7 +4,7 @@ use futures::future::join3; use serde::{de::DeserializeOwned, Deserialize}; use super::*; -use crate::infer::signatures::SignatureKind; +use crate::infer::signatures::InferenceSignatureKind; use std::{ collections::{BTreeMap, HashMap}, io, @@ -40,7 +40,7 @@ pub struct VoiceModel { impl VoiceModel { pub(crate) async fn read_inference_models( &self, - ) -> LoadModelResult>> { + ) -> LoadModelResult>> { let reader = VvmEntryReader::open(&self.path).await?; let (decode_model_result, predict_duration_model_result, predict_intonation_model_result) = join3( From 47953098a19c70f252e210dcff3e8b9e1843f1a1 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Thu, 9 Nov 2023 04:32:37 +0900 Subject: [PATCH 15/47] 
=?UTF-8?q?`LoadedModels`=E3=81=B8=E3=81=AE=E3=82=A2?= =?UTF-8?q?=E3=82=AF=E3=82=BB=E3=82=B9=E3=82=92=E3=83=A1=E3=82=BD=E3=83=83?= =?UTF-8?q?=E3=83=89=E8=B6=8A=E3=81=97=E3=81=AB=E3=81=99=E3=82=8B=E3=81=AE?= =?UTF-8?q?=E3=82=92=E5=BE=B9=E5=BA=95=E3=81=99=E3=82=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/status.rs | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index fa8d599a8..ef27db633 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -1,10 +1,9 @@ use super::*; use crate::infer::{ - signatures::InferenceSignatureKind, InferenceInput, InferenceRuntime, InferenceSessionOptions, - InferenceSessionSet, InferenceSignature, SupportsInferenceInputTensors, - SupportsInferenceOutput, + signatures::InferenceSignatureKind, InferenceInput, InferenceRuntime, InferenceSessionCell, + InferenceSessionOptions, InferenceSessionSet, InferenceSignature, + SupportsInferenceInputTensors, SupportsInferenceOutput, }; -use derive_more::Index; use educe::Educe; use itertools::iproduct; @@ -93,9 +92,7 @@ impl Status { R: SupportsInferenceInputTensors + SupportsInferenceOutput<::Output>, { - let sess = self.loaded_models.lock().unwrap()[model_id] - .session_set - .get(); + let sess = self.loaded_models.lock().unwrap().get(model_id); tokio::task::spawn_blocking(move || sess.run(input)) .await @@ -106,7 +103,7 @@ impl Status { /// 読み込んだモデルの`Session`とそのメタ情報を保有し、追加/削除/取得の操作を提供する。 /// /// この構造体のメソッドは、すべて一瞬で完了すべきである。 -#[derive(Educe, Index)] +#[derive(Educe)] #[educe(Default(bound = "R: InferenceRuntime"))] struct LoadedModels(BTreeMap>); @@ -149,6 +146,17 @@ impl LoadedModels { Ok((model_id.clone(), model_inner_id)) } + /// # Panics + /// + /// `self`が`model_id`を含んでいないとき、パニックする。 + fn get(&self, model_id: &VoiceModelId) -> InferenceSessionCell + where + I: 
InferenceInput, + I::Signature: InferenceSignature, + { + self.0[model_id].session_set.get() + } + fn contains_voice_model(&self, model_id: &VoiceModelId) -> bool { self.0.contains_key(model_id) } From a5dbbddaa3109d7c859f30feadc2fd89900c41d2 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Thu, 9 Nov 2023 05:04:53 +0900 Subject: [PATCH 16/47] Minor refactor --- crates/voicevox_core/src/infer.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index d12b15418..7bd39d75f 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -36,7 +36,7 @@ impl<'a, T: RunContext<'a>> T { where T::Runtime: SupportsInferenceInputTensor, { - >::input(tensor, self); + T::Runtime::input(tensor, self); self } } @@ -106,7 +106,7 @@ impl InferenceSessionSet { I::Signature: InferenceSignature, { InferenceSessionCell { - inner: self.0[::KIND].clone(), + inner: self.0[I::Signature::KIND].clone(), marker: PhantomData, } } @@ -127,8 +127,8 @@ impl< self, input: I, ) -> crate::Result<::Output> { - let mut inner = self.inner.lock().unwrap(); - let mut ctx = R::RunContext::from(&mut inner); + let inner = &mut *self.inner.lock().unwrap(); + let mut ctx = inner.into(); R::input(input, &mut ctx); R::run(ctx).map_err(|e| ErrorRepr::InferenceFailed(e).into()) } From 525f4b1821fff4dc67b0469497d3e4213f5e150a Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Thu, 9 Nov 2023 05:10:09 +0900 Subject: [PATCH 17/47] =?UTF-8?q?`InferenceInput`=20=E2=86=92=20`Inference?= =?UTF-8?q?InputSignature`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer.rs | 22 ++++++++++--------- .../src/infer/runtimes/onnxruntime.rs | 4 ++-- crates/voicevox_core/src/infer/signatures.rs | 16 +++++++------- crates/voicevox_core/src/status.rs | 12 +++++----- 4 files changed, 28 insertions(+), 26 deletions(-) diff --git 
a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 7bd39d75f..560306de1 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -42,23 +42,25 @@ impl<'a, T: RunContext<'a>> T { } pub(crate) trait SupportsInferenceSignature: - SupportsInferenceInputTensors + SupportsInferenceOutput + SupportsInferenceInputSignature + SupportsInferenceOutput { } impl< - R: SupportsInferenceInputTensors + SupportsInferenceOutput, + R: SupportsInferenceInputSignature + SupportsInferenceOutput, S: InferenceSignature, > SupportsInferenceSignature for R { } pub(crate) trait SupportsInferenceInputTensor: InferenceRuntime { - fn input(tensor: I, ctx: &mut Self::RunContext<'_>); + fn input(input: I, ctx: &mut Self::RunContext<'_>); } -pub(crate) trait SupportsInferenceInputTensors: InferenceRuntime { - fn input(tensors: I, ctx: &mut Self::RunContext<'_>); +pub(crate) trait SupportsInferenceInputSignature: + InferenceRuntime +{ + fn input(input: I, ctx: &mut Self::RunContext<'_>); } pub(crate) trait SupportsInferenceOutput: InferenceRuntime { @@ -67,12 +69,12 @@ pub(crate) trait SupportsInferenceOutput: InferenceRuntime { pub(crate) trait InferenceSignature: Sized + Send + 'static { type Kind: Enum + Copy; - type Input: InferenceInput; + type Input: InferenceInputSignature; type Output: Send; const KIND: Self::Kind; } -pub(crate) trait InferenceInput: Send + 'static { +pub(crate) trait InferenceInputSignature: Send + 'static { type Signature: InferenceSignature; } @@ -102,7 +104,7 @@ impl InferenceSessionSet { impl InferenceSessionSet { pub(crate) fn get(&self) -> InferenceSessionCell where - I: InferenceInput, + I: InferenceInputSignature, I::Signature: InferenceSignature, { InferenceSessionCell { @@ -118,9 +120,9 @@ pub(crate) struct InferenceSessionCell { } impl< - R: SupportsInferenceInputTensors + R: SupportsInferenceInputSignature + SupportsInferenceOutput<::Output>, - I: InferenceInput, + I: InferenceInputSignature, > 
InferenceSessionCell { pub(crate) fn run( diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs index e0c54dd6e..49979b0f6 100644 --- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -117,9 +117,9 @@ impl<'sess> RunContext<'sess> for OnnxruntimeRunContext<'sess> { impl SupportsInferenceInputTensor> for Onnxruntime { - fn input(tensor: Array, ctx: &mut Self::RunContext<'_>) { + fn input(input: Array, ctx: &mut Self::RunContext<'_>) { ctx.inputs - .push(Box::new(onnxruntime::session::NdArray::new(tensor))); + .push(Box::new(onnxruntime::session::NdArray::new(input))); } } diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/signatures.rs index ed20dd23b..375db5e28 100644 --- a/crates/voicevox_core/src/infer/signatures.rs +++ b/crates/voicevox_core/src/infer/signatures.rs @@ -2,8 +2,8 @@ use enum_map::Enum; use ndarray::{Array0, Array1, Array2}; use crate::infer::{ - InferenceInput, InferenceSignature, RunContextExt as _, SupportsInferenceInputTensor, - SupportsInferenceInputTensors, + InferenceInputSignature, InferenceSignature, RunContextExt as _, + SupportsInferenceInputSignature, SupportsInferenceInputTensor, }; #[derive(Clone, Copy, Enum)] @@ -27,12 +27,12 @@ pub(crate) struct PredictDurationInput { pub(crate) speaker_id: Array1, } -impl InferenceInput for PredictDurationInput { +impl InferenceInputSignature for PredictDurationInput { type Signature = PredictDuration; } impl>> - SupportsInferenceInputTensors for R + SupportsInferenceInputSignature for R { fn input(input: PredictDurationInput, ctx: &mut R::RunContext<'_>) { ctx.input(input.phoneme).input(input.speaker_id); @@ -59,12 +59,12 @@ pub(crate) struct PredictIntonationInput { pub(crate) speaker_id: Array1, } -impl InferenceInput for PredictIntonationInput { +impl InferenceInputSignature for PredictIntonationInput { type 
Signature = PredictIntonation; } impl> + SupportsInferenceInputTensor>> - SupportsInferenceInputTensors for R + SupportsInferenceInputSignature for R { fn input(input: PredictIntonationInput, ctx: &mut R::RunContext<'_>) { ctx.input(input.length) @@ -93,12 +93,12 @@ pub(crate) struct DecodeInput { pub(crate) speaker_id: Array1, } -impl InferenceInput for DecodeInput { +impl InferenceInputSignature for DecodeInput { type Signature = Decode; } impl> + SupportsInferenceInputTensor>> - SupportsInferenceInputTensors for R + SupportsInferenceInputSignature for R { fn input(input: DecodeInput, ctx: &mut R::RunContext<'_>) { ctx.input(input.f0) diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index ef27db633..3abfe401f 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -1,8 +1,8 @@ use super::*; use crate::infer::{ - signatures::InferenceSignatureKind, InferenceInput, InferenceRuntime, InferenceSessionCell, - InferenceSessionOptions, InferenceSessionSet, InferenceSignature, - SupportsInferenceInputTensors, SupportsInferenceOutput, + signatures::InferenceSignatureKind, InferenceInputSignature, InferenceRuntime, + InferenceSessionCell, InferenceSessionOptions, InferenceSessionSet, InferenceSignature, + SupportsInferenceInputSignature, SupportsInferenceOutput, }; use educe::Educe; use itertools::iproduct; @@ -87,9 +87,9 @@ impl Status { input: I, ) -> Result<::Output> where - I: InferenceInput, + I: InferenceInputSignature, I::Signature: InferenceSignature, - R: SupportsInferenceInputTensors + R: SupportsInferenceInputSignature + SupportsInferenceOutput<::Output>, { let sess = self.loaded_models.lock().unwrap().get(model_id); @@ -151,7 +151,7 @@ impl LoadedModels { /// `self`が`model_id`を含んでいないとき、パニックする。 fn get(&self, model_id: &VoiceModelId) -> InferenceSessionCell where - I: InferenceInput, + I: InferenceInputSignature, I::Signature: InferenceSignature, { self.0[model_id].session_set.get() From 
26476f54c9108af5709116c43d93727ae0e07f53 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Thu, 9 Nov 2023 07:59:04 +0900 Subject: [PATCH 18/47] =?UTF-8?q?=E7=9B=B8=E4=BA=92=E5=8F=82=E7=85=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 560306de1..deb3a994d 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -69,13 +69,13 @@ pub(crate) trait SupportsInferenceOutput: InferenceRuntime { pub(crate) trait InferenceSignature: Sized + Send + 'static { type Kind: Enum + Copy; - type Input: InferenceInputSignature; + type Input: InferenceInputSignature; type Output: Send; const KIND: Self::Kind; } pub(crate) trait InferenceInputSignature: Send + 'static { - type Signature: InferenceSignature; + type Signature: InferenceSignature; } pub(crate) struct InferenceSessionSet( From fbd7d1c8016e0e0d7bdc31b064a5e44d21b41f76 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Thu, 9 Nov 2023 21:05:21 +0900 Subject: [PATCH 19/47] =?UTF-8?q?`fn=20input`=E3=81=BE=E3=82=8F=E3=82=8A?= =?UTF-8?q?=E3=82=92=E6=98=8E=E7=9E=AD=E3=81=AB=E3=81=99=E3=82=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer.rs | 13 +++--- .../src/infer/runtimes/onnxruntime.rs | 2 +- crates/voicevox_core/src/infer/signatures.rs | 40 ++++++++++++------- 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index deb3a994d..00ab6827a 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -32,11 +32,11 @@ pub(crate) trait RunContext<'a>: #[ext(RunContextExt)] impl<'a, T: RunContext<'a>> T { - fn input(&mut self, tensor: I) -> &mut Self + fn with_input(mut self, 
tensor: I) -> Self where T::Runtime: SupportsInferenceInputTensor, { - T::Runtime::input(tensor, self); + T::Runtime::push_input(tensor, &mut self); self } } @@ -54,13 +54,13 @@ impl< } pub(crate) trait SupportsInferenceInputTensor: InferenceRuntime { - fn input(input: I, ctx: &mut Self::RunContext<'_>); + fn push_input(input: I, ctx: &mut Self::RunContext<'_>); } pub(crate) trait SupportsInferenceInputSignature: InferenceRuntime { - fn input(input: I, ctx: &mut Self::RunContext<'_>); + fn make_run_context(sess: &mut Self::Session, input: I) -> Self::RunContext<'_>; } pub(crate) trait SupportsInferenceOutput: InferenceRuntime { @@ -129,9 +129,8 @@ impl< self, input: I, ) -> crate::Result<::Output> { - let inner = &mut *self.inner.lock().unwrap(); - let mut ctx = inner.into(); - R::input(input, &mut ctx); + let inner = &mut self.inner.lock().unwrap(); + let ctx = R::make_run_context(inner, input); R::run(ctx).map_err(|e| ErrorRepr::InferenceFailed(e).into()) } } diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs index 49979b0f6..8848c24e3 100644 --- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -117,7 +117,7 @@ impl<'sess> RunContext<'sess> for OnnxruntimeRunContext<'sess> { impl SupportsInferenceInputTensor> for Onnxruntime { - fn input(input: Array, ctx: &mut Self::RunContext<'_>) { + fn push_input(input: Array, ctx: &mut Self::RunContext<'_>) { ctx.inputs .push(Box::new(onnxruntime::session::NdArray::new(input))); } diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/signatures.rs index 375db5e28..9d4690c6c 100644 --- a/crates/voicevox_core/src/infer/signatures.rs +++ b/crates/voicevox_core/src/infer/signatures.rs @@ -34,8 +34,13 @@ impl InferenceInputSignature for PredictDurationInput { impl>> SupportsInferenceInputSignature for R { - fn input(input: PredictDurationInput, ctx: 
&mut R::RunContext<'_>) { - ctx.input(input.phoneme).input(input.speaker_id); + fn make_run_context( + sess: &mut Self::Session, + input: PredictDurationInput, + ) -> Self::RunContext<'_> { + Self::RunContext::from(sess) + .with_input(input.phoneme) + .with_input(input.speaker_id) } } @@ -66,15 +71,19 @@ impl InferenceInputSignature for PredictIntonationInput { impl> + SupportsInferenceInputTensor>> SupportsInferenceInputSignature for R { - fn input(input: PredictIntonationInput, ctx: &mut R::RunContext<'_>) { - ctx.input(input.length) - .input(input.vowel_phoneme) - .input(input.consonant_phoneme) - .input(input.start_accent) - .input(input.end_accent) - .input(input.start_accent_phrase) - .input(input.end_accent_phrase) - .input(input.speaker_id); + fn make_run_context( + sess: &mut Self::Session, + input: PredictIntonationInput, + ) -> Self::RunContext<'_> { + Self::RunContext::from(sess) + .with_input(input.length) + .with_input(input.vowel_phoneme) + .with_input(input.consonant_phoneme) + .with_input(input.start_accent) + .with_input(input.end_accent) + .with_input(input.start_accent_phrase) + .with_input(input.end_accent_phrase) + .with_input(input.speaker_id) } } @@ -100,9 +109,10 @@ impl InferenceInputSignature for DecodeInput { impl> + SupportsInferenceInputTensor>> SupportsInferenceInputSignature for R { - fn input(input: DecodeInput, ctx: &mut R::RunContext<'_>) { - ctx.input(input.f0) - .input(input.phoneme) - .input(input.speaker_id); + fn make_run_context(sess: &mut Self::Session, input: DecodeInput) -> Self::RunContext<'_> { + Self::RunContext::from(sess) + .with_input(input.f0) + .with_input(input.phoneme) + .with_input(input.speaker_id) } } From 8b4f3b6f4f5be8a083eb22427c031ba53ae85cfd Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Thu, 9 Nov 2023 22:02:41 +0900 Subject: [PATCH 20/47] =?UTF-8?q?"signature"=E3=81=AEkind=E3=81=A7?= =?UTF-8?q?=E3=81=AF=E3=81=AA=E3=81=8F"model"=E3=81=AEkind=E3=81=A8?= =?UTF-8?q?=E3=81=99=E3=82=8B?= MIME-Version: 
1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer.rs | 26 +++++++++++--------- crates/voicevox_core/src/infer/signatures.rs | 22 +++++++++++------ crates/voicevox_core/src/status.rs | 19 +++++++------- crates/voicevox_core/src/voice_model.rs | 4 +-- 4 files changed, 41 insertions(+), 30 deletions(-) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 00ab6827a..1559828cf 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -67,25 +67,29 @@ pub(crate) trait SupportsInferenceOutput: InferenceRuntime { fn run(ctx: Self::RunContext<'_>) -> anyhow::Result; } +pub(crate) trait InferenceModelGroup { + type Kind: Copy + Enum; +} + pub(crate) trait InferenceSignature: Sized + Send + 'static { - type Kind: Enum + Copy; + type ModelGroup: InferenceModelGroup; type Input: InferenceInputSignature; type Output: Send; - const KIND: Self::Kind; + const MODEL: ::Kind; } pub(crate) trait InferenceInputSignature: Send + 'static { type Signature: InferenceSignature; } -pub(crate) struct InferenceSessionSet( - EnumMap>>, +pub(crate) struct InferenceSessionSet( + EnumMap>>, ); -impl InferenceSessionSet { +impl InferenceSessionSet { pub(crate) fn new( - model_bytes: &EnumMap>, - mut options: impl FnMut(K) -> InferenceSessionOptions, + model_bytes: &EnumMap>, + mut options: impl FnMut(G::Kind) -> InferenceSessionOptions, ) -> anyhow::Result { let mut sessions = model_bytes .iter() @@ -95,20 +99,20 @@ impl InferenceSessionSet { }) .collect::>>()?; - Ok(Self(EnumMap::::from_fn(|k| { + Ok(Self(EnumMap::::from_fn(|k| { sessions.remove(&k.into_usize()).expect("should exist") }))) } } -impl InferenceSessionSet { +impl InferenceSessionSet { pub(crate) fn get(&self) -> InferenceSessionCell where I: InferenceInputSignature, - I::Signature: InferenceSignature, + I::Signature: InferenceSignature, { InferenceSessionCell { - inner: self.0[I::Signature::KIND].clone(), + 
inner: self.0[I::Signature::MODEL].clone(), marker: PhantomData, } } diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/signatures.rs index 9d4690c6c..2ef75219a 100644 --- a/crates/voicevox_core/src/infer/signatures.rs +++ b/crates/voicevox_core/src/infer/signatures.rs @@ -2,12 +2,18 @@ use enum_map::Enum; use ndarray::{Array0, Array1, Array2}; use crate::infer::{ - InferenceInputSignature, InferenceSignature, RunContextExt as _, + InferenceInputSignature, InferenceModelGroup, InferenceSignature, RunContextExt as _, SupportsInferenceInputSignature, SupportsInferenceInputTensor, }; +pub(crate) enum InferenceModelGroupImpl {} + +impl InferenceModelGroup for InferenceModelGroupImpl { + type Kind = InferenceModelKindImpl; +} + #[derive(Clone, Copy, Enum)] -pub(crate) enum InferenceSignatureKind { +pub(crate) enum InferenceModelKindImpl { PredictDuration, PredictIntonation, Decode, @@ -16,10 +22,10 @@ pub(crate) enum InferenceSignatureKind { pub(crate) enum PredictDuration {} impl InferenceSignature for PredictDuration { - type Kind = InferenceSignatureKind; + type ModelGroup = InferenceModelGroupImpl; type Input = PredictDurationInput; type Output = (Vec,); - const KIND: Self::Kind = InferenceSignatureKind::PredictDuration; + const MODEL: InferenceModelKindImpl = InferenceModelKindImpl::PredictDuration; } pub(crate) struct PredictDurationInput { @@ -47,10 +53,10 @@ impl>> pub(crate) enum PredictIntonation {} impl InferenceSignature for PredictIntonation { - type Kind = InferenceSignatureKind; + type ModelGroup = InferenceModelGroupImpl; type Input = PredictIntonationInput; type Output = (Vec,); - const KIND: Self::Kind = InferenceSignatureKind::PredictIntonation; + const MODEL: InferenceModelKindImpl = InferenceModelKindImpl::PredictIntonation; } pub(crate) struct PredictIntonationInput { @@ -90,10 +96,10 @@ impl> + SupportsInferenceInputTensor pub(crate) enum Decode {} impl InferenceSignature for Decode { - type Kind = 
InferenceSignatureKind; + type ModelGroup = InferenceModelGroupImpl; type Input = DecodeInput; type Output = (Vec,); - const KIND: Self::Kind = InferenceSignatureKind::Decode; + const MODEL: InferenceModelKindImpl = InferenceModelKindImpl::Decode; } pub(crate) struct DecodeInput { diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index 3abfe401f..6c88cfe96 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -1,8 +1,9 @@ use super::*; use crate::infer::{ - signatures::InferenceSignatureKind, InferenceInputSignature, InferenceRuntime, - InferenceSessionCell, InferenceSessionOptions, InferenceSessionSet, InferenceSignature, - SupportsInferenceInputSignature, SupportsInferenceOutput, + signatures::{InferenceModelGroupImpl, InferenceModelKindImpl}, + InferenceInputSignature, InferenceRuntime, InferenceSessionCell, InferenceSessionOptions, + InferenceSessionSet, InferenceSignature, SupportsInferenceInputSignature, + SupportsInferenceOutput, }; use educe::Educe; use itertools::iproduct; @@ -33,10 +34,10 @@ impl Status { let model_bytes = &model.read_inference_models().await?; let session_set = InferenceSessionSet::new(model_bytes, |kind| match kind { - InferenceSignatureKind::PredictDuration | InferenceSignatureKind::PredictIntonation => { + InferenceModelKindImpl::PredictDuration | InferenceModelKindImpl::PredictIntonation => { self.light_session_options } - InferenceSignatureKind::Decode => self.heavy_session_options, + InferenceModelKindImpl::Decode => self.heavy_session_options, }) .map_err(|source| LoadModelError { path: model.path().clone(), @@ -88,7 +89,7 @@ impl Status { ) -> Result<::Output> where I: InferenceInputSignature, - I::Signature: InferenceSignature, + I::Signature: InferenceSignature, R: SupportsInferenceInputSignature + SupportsInferenceOutput<::Output>, { @@ -110,7 +111,7 @@ struct LoadedModels(BTreeMap>) struct LoadedModel { model_inner_ids: BTreeMap, metas: VoiceModelMeta, - 
session_set: InferenceSessionSet, + session_set: InferenceSessionSet, } impl LoadedModels { @@ -152,7 +153,7 @@ impl LoadedModels { fn get(&self, model_id: &VoiceModelId) -> InferenceSessionCell where I: InferenceInputSignature, - I::Signature: InferenceSignature, + I::Signature: InferenceSignature, { self.0[model_id].session_set.get() } @@ -198,7 +199,7 @@ impl LoadedModels { fn insert( &mut self, model: &VoiceModel, - session_set: InferenceSessionSet, + session_set: InferenceSessionSet, ) -> Result<()> { self.ensure_acceptable(model)?; diff --git a/crates/voicevox_core/src/voice_model.rs b/crates/voicevox_core/src/voice_model.rs index 1be2b6a8b..a522bc22e 100644 --- a/crates/voicevox_core/src/voice_model.rs +++ b/crates/voicevox_core/src/voice_model.rs @@ -4,7 +4,7 @@ use futures::future::join3; use serde::{de::DeserializeOwned, Deserialize}; use super::*; -use crate::infer::signatures::InferenceSignatureKind; +use crate::infer::signatures::InferenceModelKindImpl; use std::{ collections::{BTreeMap, HashMap}, io, @@ -40,7 +40,7 @@ pub struct VoiceModel { impl VoiceModel { pub(crate) async fn read_inference_models( &self, - ) -> LoadModelResult>> { + ) -> LoadModelResult>> { let reader = VvmEntryReader::open(&self.path).await?; let (decode_model_result, predict_duration_model_result, predict_intonation_model_result) = join3( From c40afd578a2c1231598d2cc6563f7f3c2050b6db Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sat, 11 Nov 2023 10:42:16 +0900 Subject: [PATCH 21/47] =?UTF-8?q?"model"=E3=81=A7=E3=81=AF=E3=81=AA?= =?UTF-8?q?=E3=81=8F"inference"=E3=81=A8=E5=91=BC=E3=81=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer.rs | 16 +++++++------- crates/voicevox_core/src/infer/signatures.rs | 22 ++++++++++---------- crates/voicevox_core/src/status.rs | 14 ++++++------- crates/voicevox_core/src/voice_model.rs | 4 ++-- 4 files changed, 28 insertions(+), 28 deletions(-) diff --git 
a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 1559828cf..9df7365f5 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -67,26 +67,26 @@ pub(crate) trait SupportsInferenceOutput: InferenceRuntime { fn run(ctx: Self::RunContext<'_>) -> anyhow::Result; } -pub(crate) trait InferenceModelGroup { +pub(crate) trait InferenceGroup { type Kind: Copy + Enum; } pub(crate) trait InferenceSignature: Sized + Send + 'static { - type ModelGroup: InferenceModelGroup; + type Group: InferenceGroup; type Input: InferenceInputSignature; type Output: Send; - const MODEL: ::Kind; + const INFERENCE: ::Kind; } pub(crate) trait InferenceInputSignature: Send + 'static { type Signature: InferenceSignature; } -pub(crate) struct InferenceSessionSet( +pub(crate) struct InferenceSessionSet( EnumMap>>, ); -impl InferenceSessionSet { +impl InferenceSessionSet { pub(crate) fn new( model_bytes: &EnumMap>, mut options: impl FnMut(G::Kind) -> InferenceSessionOptions, @@ -105,14 +105,14 @@ impl InferenceSessionSet { } } -impl InferenceSessionSet { +impl InferenceSessionSet { pub(crate) fn get(&self) -> InferenceSessionCell where I: InferenceInputSignature, - I::Signature: InferenceSignature, + I::Signature: InferenceSignature, { InferenceSessionCell { - inner: self.0[I::Signature::MODEL].clone(), + inner: self.0[I::Signature::INFERENCE].clone(), marker: PhantomData, } } diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/signatures.rs index 2ef75219a..d04e7b3ad 100644 --- a/crates/voicevox_core/src/infer/signatures.rs +++ b/crates/voicevox_core/src/infer/signatures.rs @@ -2,18 +2,18 @@ use enum_map::Enum; use ndarray::{Array0, Array1, Array2}; use crate::infer::{ - InferenceInputSignature, InferenceModelGroup, InferenceSignature, RunContextExt as _, + InferenceGroup, InferenceInputSignature, InferenceSignature, RunContextExt as _, SupportsInferenceInputSignature, SupportsInferenceInputTensor, }; 
-pub(crate) enum InferenceModelGroupImpl {} +pub(crate) enum InferenceGroupImpl {} -impl InferenceModelGroup for InferenceModelGroupImpl { - type Kind = InferenceModelKindImpl; +impl InferenceGroup for InferenceGroupImpl { + type Kind = InferencelKindImpl; } #[derive(Clone, Copy, Enum)] -pub(crate) enum InferenceModelKindImpl { +pub(crate) enum InferencelKindImpl { PredictDuration, PredictIntonation, Decode, @@ -22,10 +22,10 @@ pub(crate) enum InferenceModelKindImpl { pub(crate) enum PredictDuration {} impl InferenceSignature for PredictDuration { - type ModelGroup = InferenceModelGroupImpl; + type Group = InferenceGroupImpl; type Input = PredictDurationInput; type Output = (Vec,); - const MODEL: InferenceModelKindImpl = InferenceModelKindImpl::PredictDuration; + const INFERENCE: InferencelKindImpl = InferencelKindImpl::PredictDuration; } pub(crate) struct PredictDurationInput { @@ -53,10 +53,10 @@ impl>> pub(crate) enum PredictIntonation {} impl InferenceSignature for PredictIntonation { - type ModelGroup = InferenceModelGroupImpl; + type Group = InferenceGroupImpl; type Input = PredictIntonationInput; type Output = (Vec,); - const MODEL: InferenceModelKindImpl = InferenceModelKindImpl::PredictIntonation; + const INFERENCE: InferencelKindImpl = InferencelKindImpl::PredictIntonation; } pub(crate) struct PredictIntonationInput { @@ -96,10 +96,10 @@ impl> + SupportsInferenceInputTensor pub(crate) enum Decode {} impl InferenceSignature for Decode { - type ModelGroup = InferenceModelGroupImpl; + type Group = InferenceGroupImpl; type Input = DecodeInput; type Output = (Vec,); - const MODEL: InferenceModelKindImpl = InferenceModelKindImpl::Decode; + const INFERENCE: InferencelKindImpl = InferencelKindImpl::Decode; } pub(crate) struct DecodeInput { diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index 6c88cfe96..16bc04a4d 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -1,6 +1,6 @@ use 
super::*; use crate::infer::{ - signatures::{InferenceModelGroupImpl, InferenceModelKindImpl}, + signatures::{InferenceGroupImpl, InferencelKindImpl}, InferenceInputSignature, InferenceRuntime, InferenceSessionCell, InferenceSessionOptions, InferenceSessionSet, InferenceSignature, SupportsInferenceInputSignature, SupportsInferenceOutput, @@ -34,10 +34,10 @@ impl Status { let model_bytes = &model.read_inference_models().await?; let session_set = InferenceSessionSet::new(model_bytes, |kind| match kind { - InferenceModelKindImpl::PredictDuration | InferenceModelKindImpl::PredictIntonation => { + InferencelKindImpl::PredictDuration | InferencelKindImpl::PredictIntonation => { self.light_session_options } - InferenceModelKindImpl::Decode => self.heavy_session_options, + InferencelKindImpl::Decode => self.heavy_session_options, }) .map_err(|source| LoadModelError { path: model.path().clone(), @@ -89,7 +89,7 @@ impl Status { ) -> Result<::Output> where I: InferenceInputSignature, - I::Signature: InferenceSignature, + I::Signature: InferenceSignature, R: SupportsInferenceInputSignature + SupportsInferenceOutput<::Output>, { @@ -111,7 +111,7 @@ struct LoadedModels(BTreeMap>) struct LoadedModel { model_inner_ids: BTreeMap, metas: VoiceModelMeta, - session_set: InferenceSessionSet, + session_set: InferenceSessionSet, } impl LoadedModels { @@ -153,7 +153,7 @@ impl LoadedModels { fn get(&self, model_id: &VoiceModelId) -> InferenceSessionCell where I: InferenceInputSignature, - I::Signature: InferenceSignature, + I::Signature: InferenceSignature, { self.0[model_id].session_set.get() } @@ -199,7 +199,7 @@ impl LoadedModels { fn insert( &mut self, model: &VoiceModel, - session_set: InferenceSessionSet, + session_set: InferenceSessionSet, ) -> Result<()> { self.ensure_acceptable(model)?; diff --git a/crates/voicevox_core/src/voice_model.rs b/crates/voicevox_core/src/voice_model.rs index a522bc22e..136bc3742 100644 --- a/crates/voicevox_core/src/voice_model.rs +++ 
b/crates/voicevox_core/src/voice_model.rs @@ -4,7 +4,7 @@ use futures::future::join3; use serde::{de::DeserializeOwned, Deserialize}; use super::*; -use crate::infer::signatures::InferenceModelKindImpl; +use crate::infer::signatures::InferencelKindImpl; use std::{ collections::{BTreeMap, HashMap}, io, @@ -40,7 +40,7 @@ pub struct VoiceModel { impl VoiceModel { pub(crate) async fn read_inference_models( &self, - ) -> LoadModelResult>> { + ) -> LoadModelResult>> { let reader = VvmEntryReader::open(&self.path).await?; let (decode_model_result, predict_duration_model_result, predict_intonation_model_result) = join3( From 81b5804037ea21fe0b3fe0ae934fe1d2a2548c74 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sat, 11 Nov 2023 13:34:40 +0900 Subject: [PATCH 22/47] =?UTF-8?q?=E3=83=A9=E3=83=B3=E3=82=BF=E3=82=A4?= =?UTF-8?q?=E3=83=A0=E3=81=AF=E4=BB=BB=E6=84=8F=E6=AC=A1=E5=85=83=E4=BB=BB?= =?UTF-8?q?=E6=84=8F=E5=80=8B=E6=95=B0=E3=81=AE=E5=85=A5=E5=87=BA=E5=8A=9B?= =?UTF-8?q?=E3=81=8C=E3=81=A7=E3=81=8D=E3=82=8B=E3=81=A8=E4=BB=AE=E5=AE=9A?= =?UTF-8?q?=E3=81=99=E3=82=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/engine/synthesis_engine.rs | 12 +- crates/voicevox_core/src/infer.rs | 114 +++++++++------ .../src/infer/runtimes/onnxruntime.rs | 71 +++++---- crates/voicevox_core/src/infer/signatures.rs | 138 +++++++++++++----- crates/voicevox_core/src/inference_core.rs | 29 ++-- crates/voicevox_core/src/status.rs | 5 +- 6 files changed, 226 insertions(+), 143 deletions(-) diff --git a/crates/voicevox_core/src/engine/synthesis_engine.rs b/crates/voicevox_core/src/engine/synthesis_engine.rs index b171f978d..c70742f16 100644 --- a/crates/voicevox_core/src/engine/synthesis_engine.rs +++ b/crates/voicevox_core/src/engine/synthesis_engine.rs @@ -5,10 +5,7 @@ use std::sync::Arc; use super::full_context_label::Utterance; use super::open_jtalk::OpenJtalk; use super::*; -use crate::infer::{ - signatures::{Decode, 
PredictDuration, PredictIntonation}, - InferenceRuntime, SupportsInferenceSignature, -}; +use crate::infer::InferenceRuntime; use crate::numerics::F32Ext as _; use crate::InferenceCore; @@ -26,12 +23,7 @@ pub(crate) struct SynthesisEngine { open_jtalk: Arc, } -impl< - R: SupportsInferenceSignature - + SupportsInferenceSignature - + SupportsInferenceSignature, - > SynthesisEngine -{ +impl SynthesisEngine { pub fn inference_core(&self) -> &InferenceCore { &self.inference_core } diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 9df7365f5..0eaea5b63 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -7,21 +7,28 @@ use std::{collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc}; use derive_new::new; use easy_ext::ext; use enum_map::{Enum, EnumMap}; +use ndarray::{Array, ArrayD, Dimension, ShapeError}; use thiserror::Error; use crate::{ErrorRepr, SupportedDevices}; pub(crate) trait InferenceRuntime: 'static { - type Session: InferenceSession; + type Session: Sized + Send + 'static; type RunContext<'a>: RunContext<'a, Runtime = Self>; + fn supported_devices() -> crate::Result; -} -pub(crate) trait InferenceSession: Sized + Send + 'static { - fn new( + fn new_session( model: impl FnOnce() -> std::result::Result, DecryptModelError>, options: InferenceSessionOptions, - ) -> anyhow::Result; + ) -> anyhow::Result; + + fn push_input( + input: Array, + ctx: &mut Self::RunContext<'_>, + ); + + fn run(ctx: Self::RunContext<'_>) -> anyhow::Result>; } pub(crate) trait RunContext<'a>: @@ -32,54 +39,56 @@ pub(crate) trait RunContext<'a>: #[ext(RunContextExt)] impl<'a, T: RunContext<'a>> T { - fn with_input(mut self, tensor: I) -> Self - where - T::Runtime: SupportsInferenceInputTensor, - { + fn with_input(mut self, tensor: Array) -> Self { T::Runtime::push_input(tensor, &mut self); self } } -pub(crate) trait SupportsInferenceSignature: - SupportsInferenceInputSignature + SupportsInferenceOutput 
-{ +pub(crate) trait InferenceGroup { + type Kind: Copy + Enum; } -impl< - R: SupportsInferenceInputSignature + SupportsInferenceOutput, - S: InferenceSignature, - > SupportsInferenceSignature for R -{ +pub(crate) trait InferenceSignature: Sized + Send + 'static { + type Group: InferenceGroup; + type Input: InferenceInputSignature; + type Output: TryFrom, Error = anyhow::Error> + Send; + const INFERENCE: ::Kind; } -pub(crate) trait SupportsInferenceInputTensor: InferenceRuntime { - fn push_input(input: I, ctx: &mut Self::RunContext<'_>); +pub(crate) trait InferenceInputSignature: Send + 'static { + type Signature: InferenceSignature; + fn make_run_context(self, sess: &mut R::Session) -> R::RunContext<'_>; } -pub(crate) trait SupportsInferenceInputSignature: - InferenceRuntime -{ - fn make_run_context(sess: &mut Self::Session, input: I) -> Self::RunContext<'_>; -} +pub(crate) trait InputScalar: sealed::InputScalar + Debug + 'static {} + +impl InputScalar for i64 {} +impl InputScalar for f32 {} -pub(crate) trait SupportsInferenceOutput: InferenceRuntime { - fn run(ctx: Self::RunContext<'_>) -> anyhow::Result; +pub(crate) trait OutputScalar: Sized { + fn extract_dyn_dim(tensor: AnyTensor) -> std::result::Result, ExtractError>; } -pub(crate) trait InferenceGroup { - type Kind: Copy + Enum; +impl OutputScalar for f32 { + fn extract_dyn_dim(tensor: AnyTensor) -> std::result::Result, ExtractError> { + match tensor { + AnyTensor::Float32(tensor) => Ok(tensor), + } + } } -pub(crate) trait InferenceSignature: Sized + Send + 'static { - type Group: InferenceGroup; - type Input: InferenceInputSignature; - type Output: Send; - const INFERENCE: ::Kind; +pub(crate) enum AnyTensor { + Float32(ArrayD), } -pub(crate) trait InferenceInputSignature: Send + 'static { - type Signature: InferenceSignature; +impl TryFrom for Array { + type Error = ExtractError; + + fn try_from(tensor: AnyTensor) -> Result { + let this = A::extract_dyn_dim(tensor)?.into_dimensionality()?; + Ok(this) + } } 
pub(crate) struct InferenceSessionSet( @@ -94,7 +103,7 @@ impl InferenceSessionSet { let mut sessions = model_bytes .iter() .map(|(k, m)| { - let sess = R::Session::new(|| model_file::decrypt(m), options(k))?; + let sess = R::new_session(|| model_file::decrypt(m), options(k))?; Ok((k.into_usize(), std::sync::Mutex::new(sess).into())) }) .collect::>>()?; @@ -123,19 +132,16 @@ pub(crate) struct InferenceSessionCell { marker: PhantomData, } -impl< - R: SupportsInferenceInputSignature - + SupportsInferenceOutput<::Output>, - I: InferenceInputSignature, - > InferenceSessionCell -{ +impl InferenceSessionCell { pub(crate) fn run( self, input: I, ) -> crate::Result<::Output> { let inner = &mut self.inner.lock().unwrap(); - let ctx = R::make_run_context(inner, input); - R::run(ctx).map_err(|e| ErrorRepr::InferenceFailed(e).into()) + let ctx = input.make_run_context::(inner); + R::run(ctx) + .and_then(TryInto::try_into) + .map_err(|e| ErrorRepr::InferenceFailed(e).into()) } } @@ -145,6 +151,26 @@ pub(crate) struct InferenceSessionOptions { pub(crate) use_gpu: bool, } +#[derive(Error, Debug)] +pub(crate) enum ExtractError { + #[error(transparent)] + Shape(#[from] ShapeError), +} + #[derive(Error, Debug)] #[error("不正なモデルファイルです")] pub(crate) struct DecryptModelError; + +mod sealed { + pub(crate) trait InputScalar: OnnxruntimeInputScalar {} + + impl InputScalar for i64 {} + impl InputScalar for f32 {} + + pub(crate) trait OnnxruntimeInputScalar: + onnxruntime::TypeToTensorElementDataType + { + } + + impl OnnxruntimeInputScalar for T {} +} diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs index 8848c24e3..26bc93655 100644 --- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -3,7 +3,7 @@ use std::fmt::Debug; use ndarray::{Array, Dimension}; use once_cell::sync::Lazy; use onnxruntime::{ - environment::Environment, GraphOptimizationLevel, 
LoggingLevel, TypeToTensorElementDataType, + environment::Environment, GraphOptimizationLevel, LoggingLevel, TensorElementDataType, }; use self::assert_send::AssertSend; @@ -11,8 +11,8 @@ use crate::{ devices::SupportedDevices, error::ErrorRepr, infer::{ - DecryptModelError, InferenceRuntime, InferenceSession, InferenceSessionOptions, RunContext, - SupportsInferenceInputTensor, SupportsInferenceOutput, + AnyTensor, DecryptModelError, InferenceRuntime, InferenceSessionOptions, InputScalar, + RunContext, }, }; @@ -44,13 +44,11 @@ impl InferenceRuntime for Onnxruntime { dml: dml_support, }) } -} -impl InferenceSession for AssertSend> { - fn new( + fn new_session( model: impl FnOnce() -> std::result::Result, DecryptModelError>, options: InferenceSessionOptions, - ) -> anyhow::Result { + ) -> anyhow::Result { let mut builder = ENVIRONMENT .new_session_builder()? .with_optimization_level(GraphOptimizationLevel::Basic)? @@ -75,8 +73,8 @@ impl InferenceSession for AssertSend> { } let model = model()?; - let this = builder.with_model_from_memory(model)?.into(); - return Ok(this); + let sess = builder.with_model_from_memory(model)?.into(); + return Ok(sess); static ENVIRONMENT: Lazy = Lazy::new(|| { Environment::builder() @@ -92,6 +90,39 @@ impl InferenceSession for AssertSend> { LoggingLevel::Warning }; } + + fn push_input( + input: Array, + ctx: &mut Self::RunContext<'_>, + ) { + ctx.inputs + .push(Box::new(onnxruntime::session::NdArray::new(input))); + } + + fn run( + OnnxruntimeRunContext { sess, mut inputs }: OnnxruntimeRunContext<'_>, + ) -> anyhow::Result> { + // FIXME: 現状では`f32`のみ対応。実行時にsessionからdatatypeが取れるので、別の型の対応も + // おそらく可能ではあるが、それが必要になるよりもortクレートへの引越しが先になると思われる + // のでこのままにする。 + + if !sess + .outputs + .iter() + .all(|info| matches!(info.output_type, TensorElementDataType::Float)) + { + unimplemented!( + "currently only `ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT` is supported for output", + ); + } + + let outputs = sess.run::(inputs.iter_mut().map(|t| &mut **t as 
&mut _).collect())?; + + Ok(outputs + .iter() + .map(|o| AnyTensor::Float32((*o).clone().into_owned())) + .collect()) + } } pub(crate) struct OnnxruntimeRunContext<'sess> { @@ -114,28 +145,6 @@ impl<'sess> RunContext<'sess> for OnnxruntimeRunContext<'sess> { type Runtime = Onnxruntime; } -impl - SupportsInferenceInputTensor> for Onnxruntime -{ - fn push_input(input: Array, ctx: &mut Self::RunContext<'_>) { - ctx.inputs - .push(Box::new(onnxruntime::session::NdArray::new(input))); - } -} - -impl SupportsInferenceOutput<(Vec,)> - for Onnxruntime -{ - fn run( - OnnxruntimeRunContext { sess, mut inputs }: OnnxruntimeRunContext<'_>, - ) -> anyhow::Result<(Vec,)> { - let outputs = sess.run(inputs.iter_mut().map(|t| &mut **t as &mut _).collect())?; - - // FIXME: 2個以上の出力や二次元以上の出力をちゃんとしたやりかたで弾く - Ok((outputs[0].as_slice().unwrap().to_owned(),)) - } -} - // FIXME: 以下のことをちゃんと確認した後、onnxruntime-rs側で`Session`が`Send`であると宣言する。 // https://github.com/VOICEVOX/voicevox_core/issues/307#issuecomment-1276184614 mod assert_send { diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/signatures.rs index d04e7b3ad..ac46c6444 100644 --- a/crates/voicevox_core/src/infer/signatures.rs +++ b/crates/voicevox_core/src/infer/signatures.rs @@ -1,9 +1,10 @@ +use anyhow::ensure; use enum_map::Enum; use ndarray::{Array0, Array1, Array2}; -use crate::infer::{ - InferenceGroup, InferenceInputSignature, InferenceSignature, RunContextExt as _, - SupportsInferenceInputSignature, SupportsInferenceInputTensor, +use super::{ + AnyTensor, InferenceGroup, InferenceInputSignature, InferenceRuntime, InferenceSignature, + RunContextExt as _, }; pub(crate) enum InferenceGroupImpl {} @@ -24,7 +25,7 @@ pub(crate) enum PredictDuration {} impl InferenceSignature for PredictDuration { type Group = InferenceGroupImpl; type Input = PredictDurationInput; - type Output = (Vec,); + type Output = PredictDurationOutput; const INFERENCE: InferencelKindImpl = 
InferencelKindImpl::PredictDuration; } @@ -35,18 +36,36 @@ pub(crate) struct PredictDurationInput { impl InferenceInputSignature for PredictDurationInput { type Signature = PredictDuration; + + fn make_run_context(self, sess: &mut R::Session) -> R::RunContext<'_> { + R::RunContext::from(sess) + .with_input(self.phoneme) + .with_input(self.speaker_id) + } +} + +pub(crate) struct PredictDurationOutput { + pub(crate) phoneme_length: Array1, } -impl>> - SupportsInferenceInputSignature for R -{ - fn make_run_context( - sess: &mut Self::Session, - input: PredictDurationInput, - ) -> Self::RunContext<'_> { - Self::RunContext::from(sess) - .with_input(input.phoneme) - .with_input(input.speaker_id) +impl TryFrom> for PredictDurationOutput { + type Error = anyhow::Error; + + fn try_from(tensors: Vec) -> Result { + ensure!( + tensors.len() == 1, + "expected 1 tensor(s), got {}", + tensors.len(), + ); + + let mut tensors = tensors.into_iter(); + let this = Self { + phoneme_length: tensors + .next() + .expect("the length should have been checked") + .try_into()?, + }; + Ok(this) } } @@ -55,7 +74,7 @@ pub(crate) enum PredictIntonation {} impl InferenceSignature for PredictIntonation { type Group = InferenceGroupImpl; type Input = PredictIntonationInput; - type Output = (Vec,); + type Output = PredictIntonationOutput; const INFERENCE: InferencelKindImpl = InferencelKindImpl::PredictIntonation; } @@ -72,24 +91,42 @@ pub(crate) struct PredictIntonationInput { impl InferenceInputSignature for PredictIntonationInput { type Signature = PredictIntonation; + + fn make_run_context(self, sess: &mut R::Session) -> R::RunContext<'_> { + R::RunContext::from(sess) + .with_input(self.length) + .with_input(self.vowel_phoneme) + .with_input(self.consonant_phoneme) + .with_input(self.start_accent) + .with_input(self.end_accent) + .with_input(self.start_accent_phrase) + .with_input(self.end_accent_phrase) + .with_input(self.speaker_id) + } } -impl> + SupportsInferenceInputTensor>> - 
SupportsInferenceInputSignature for R -{ - fn make_run_context( - sess: &mut Self::Session, - input: PredictIntonationInput, - ) -> Self::RunContext<'_> { - Self::RunContext::from(sess) - .with_input(input.length) - .with_input(input.vowel_phoneme) - .with_input(input.consonant_phoneme) - .with_input(input.start_accent) - .with_input(input.end_accent) - .with_input(input.start_accent_phrase) - .with_input(input.end_accent_phrase) - .with_input(input.speaker_id) +pub(crate) struct PredictIntonationOutput { + pub(crate) f0_list: Array1, +} + +impl TryFrom> for PredictIntonationOutput { + type Error = anyhow::Error; + + fn try_from(tensors: Vec) -> Result { + ensure!( + tensors.len() == 1, + "expected 1 tensor(s), got {}", + tensors.len(), + ); + + let mut tensors = tensors.into_iter(); + let this = Self { + f0_list: tensors + .next() + .expect("the length should have been checked") + .try_into()?, + }; + Ok(this) } } @@ -98,7 +135,7 @@ pub(crate) enum Decode {} impl InferenceSignature for Decode { type Group = InferenceGroupImpl; type Input = DecodeInput; - type Output = (Vec,); + type Output = DecodeOutput; const INFERENCE: InferencelKindImpl = InferencelKindImpl::Decode; } @@ -110,15 +147,36 @@ pub(crate) struct DecodeInput { impl InferenceInputSignature for DecodeInput { type Signature = Decode; + + fn make_run_context(self, sess: &mut R::Session) -> R::RunContext<'_> { + R::RunContext::from(sess) + .with_input(self.f0) + .with_input(self.phoneme) + .with_input(self.speaker_id) + } +} + +pub(crate) struct DecodeOutput { + pub(crate) wave: Array1, } -impl> + SupportsInferenceInputTensor>> - SupportsInferenceInputSignature for R -{ - fn make_run_context(sess: &mut Self::Session, input: DecodeInput) -> Self::RunContext<'_> { - Self::RunContext::from(sess) - .with_input(input.f0) - .with_input(input.phoneme) - .with_input(input.speaker_id) +impl TryFrom> for DecodeOutput { + type Error = anyhow::Error; + + fn try_from(tensors: Vec) -> Result { + ensure!( + 
tensors.len() == 1, + "expected 1 tensor(s), got {}", + tensors.len(), + ); + + let mut tensors = tensors.into_iter(); + let this = Self { + wave: tensors + .next() + .expect("the length should have been checked") + .try_into()?, + }; + Ok(this) } } diff --git a/crates/voicevox_core/src/inference_core.rs b/crates/voicevox_core/src/inference_core.rs index 264f56942..30dc37995 100644 --- a/crates/voicevox_core/src/inference_core.rs +++ b/crates/voicevox_core/src/inference_core.rs @@ -2,10 +2,10 @@ use self::status::*; use super::*; use crate::infer::{ signatures::{ - Decode, DecodeInput, PredictDuration, PredictDurationInput, PredictIntonation, - PredictIntonationInput, + DecodeInput, DecodeOutput, PredictDurationInput, PredictDurationOutput, + PredictIntonationInput, PredictIntonationOutput, }, - InferenceRuntime, SupportsInferenceSignature, + InferenceRuntime, }; const PHONEME_LENGTH_MINIMAL: f32 = 0.01; @@ -14,12 +14,7 @@ pub(crate) struct InferenceCore { status: Status, } -impl< - R: SupportsInferenceSignature - + SupportsInferenceSignature - + SupportsInferenceSignature, - > InferenceCore -{ +impl InferenceCore { pub(crate) fn new(use_gpu: bool, cpu_num_threads: u16) -> Result { if !use_gpu || Self::can_support_gpu_feature()? 
{ let status = Status::new(use_gpu, cpu_num_threads); @@ -71,7 +66,9 @@ impl< let (model_id, model_inner_id) = self.status.ids_for(style_id)?; - let (mut output,) = self + let PredictDurationOutput { + phoneme_length: output, + } = self .status .run_session( &model_id, @@ -81,6 +78,7 @@ impl< }, ) .await?; + let mut output = output.into_raw_vec(); for output_item in output.iter_mut() { if *output_item < PHONEME_LENGTH_MINIMAL { @@ -109,7 +107,7 @@ impl< let (model_id, model_inner_id) = self.status.ids_for(style_id)?; - let (output,) = self + let PredictIntonationOutput { f0_list: output } = self .status .run_session( &model_id, @@ -126,7 +124,7 @@ impl< ) .await?; - Ok(output) + Ok(output.into_raw_vec()) } pub async fn decode( @@ -159,7 +157,7 @@ impl< padding_size, ); - let (output,) = self + let DecodeOutput { wave: output } = self .status .run_session( &model_id, @@ -175,7 +173,10 @@ impl< ) .await?; - Ok(Self::trim_padding_from_output(output, padding_size)) + Ok(Self::trim_padding_from_output( + output.into_raw_vec(), + padding_size, + )) } fn make_f0_with_padding( diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index 16bc04a4d..51cabf20d 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -2,8 +2,7 @@ use super::*; use crate::infer::{ signatures::{InferenceGroupImpl, InferencelKindImpl}, InferenceInputSignature, InferenceRuntime, InferenceSessionCell, InferenceSessionOptions, - InferenceSessionSet, InferenceSignature, SupportsInferenceInputSignature, - SupportsInferenceOutput, + InferenceSessionSet, InferenceSignature, }; use educe::Educe; use itertools::iproduct; @@ -90,8 +89,6 @@ impl Status { where I: InferenceInputSignature, I::Signature: InferenceSignature, - R: SupportsInferenceInputSignature - + SupportsInferenceOutput<::Output>, { let sess = self.loaded_models.lock().unwrap().get(model_id); From 120106b8d38d4430d7b64b3720fb78ec1a7d09e8 Mon Sep 17 00:00:00 2001 From: Ryo 
Yamashita Date: Sat, 11 Nov 2023 17:48:28 +0900 Subject: [PATCH 23/47] =?UTF-8?q?voicevox=5Fcore=5Fmacros=E3=82=92?= =?UTF-8?q?=E4=BD=9C=E3=82=8A=E3=80=81"signatures"=E3=81=AE=E5=AE=9F?= =?UTF-8?q?=E8=A3=85=E3=82=92=E3=83=9E=E3=82=AF=E3=83=AD=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 10 ++ Cargo.toml | 10 +- crates/voicevox_core/Cargo.toml | 1 + crates/voicevox_core/src/infer/signatures.rs | 116 ++------------ crates/voicevox_core_macros/Cargo.toml | 14 ++ crates/voicevox_core_macros/src/lib.rs | 150 +++++++++++++++++++ 6 files changed, 187 insertions(+), 114 deletions(-) create mode 100644 crates/voicevox_core_macros/Cargo.toml create mode 100644 crates/voicevox_core_macros/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 1ab3e37f5..bd738d44a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4342,6 +4342,7 @@ dependencies = [ "tokio", "tracing", "uuid", + "voicevox_core_macros", "windows", ] @@ -4403,6 +4404,15 @@ dependencies = [ "voicevox_core", ] +[[package]] +name = "voicevox_core_macros" +version = "0.0.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.38", +] + [[package]] name = "voicevox_core_python_api" version = "0.0.0" diff --git a/Cargo.toml b/Cargo.toml index bb98404f2..ddcb356e2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,13 +1,5 @@ [workspace] -members = [ - "crates/downloader", - "crates/test_util", - "crates/voicevox_core", - "crates/voicevox_core_c_api", - "crates/voicevox_core_java_api", - "crates/voicevox_core_python_api", - "crates/xtask" -] +members = ["crates/*"] resolver = "2" [workspace.dependencies] diff --git a/crates/voicevox_core/Cargo.toml b/crates/voicevox_core/Cargo.toml index cb8d88d3c..d6ad96761 100644 --- a/crates/voicevox_core/Cargo.toml +++ b/crates/voicevox_core/Cargo.toml @@ -34,6 +34,7 @@ thiserror.workspace = true tokio.workspace = true tracing.workspace = true uuid.workspace = true +voicevox_core_macros = { path = 
"../voicevox_core_macros" } [dependencies.onnxruntime] git = "https://github.com/VOICEVOX/onnxruntime-rs.git" diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/signatures.rs index ac46c6444..469109506 100644 --- a/crates/voicevox_core/src/infer/signatures.rs +++ b/crates/voicevox_core/src/infer/signatures.rs @@ -1,11 +1,8 @@ -use anyhow::ensure; use enum_map::Enum; +use macros::{InferenceInputSignature, TryFromVecAnyTensor}; use ndarray::{Array0, Array1, Array2}; -use super::{ - AnyTensor, InferenceGroup, InferenceInputSignature, InferenceRuntime, InferenceSignature, - RunContextExt as _, -}; +use super::{AnyTensor, InferenceGroup, InferenceSignature}; pub(crate) enum InferenceGroupImpl {} @@ -29,46 +26,18 @@ impl InferenceSignature for PredictDuration { const INFERENCE: InferencelKindImpl = InferencelKindImpl::PredictDuration; } +#[derive(InferenceInputSignature)] +#[input_signature(Signature = PredictDuration)] pub(crate) struct PredictDurationInput { pub(crate) phoneme: Array1, pub(crate) speaker_id: Array1, } -impl InferenceInputSignature for PredictDurationInput { - type Signature = PredictDuration; - - fn make_run_context(self, sess: &mut R::Session) -> R::RunContext<'_> { - R::RunContext::from(sess) - .with_input(self.phoneme) - .with_input(self.speaker_id) - } -} - +#[derive(TryFromVecAnyTensor)] pub(crate) struct PredictDurationOutput { pub(crate) phoneme_length: Array1, } -impl TryFrom> for PredictDurationOutput { - type Error = anyhow::Error; - - fn try_from(tensors: Vec) -> Result { - ensure!( - tensors.len() == 1, - "expected 1 tensor(s), got {}", - tensors.len(), - ); - - let mut tensors = tensors.into_iter(); - let this = Self { - phoneme_length: tensors - .next() - .expect("the length should have been checked") - .try_into()?, - }; - Ok(this) - } -} - pub(crate) enum PredictIntonation {} impl InferenceSignature for PredictIntonation { @@ -78,6 +47,8 @@ impl InferenceSignature for PredictIntonation { const 
INFERENCE: InferencelKindImpl = InferencelKindImpl::PredictIntonation; } +#[derive(InferenceInputSignature)] +#[input_signature(Signature = PredictIntonation)] pub(crate) struct PredictIntonationInput { pub(crate) length: Array0, pub(crate) vowel_phoneme: Array1, @@ -89,47 +60,11 @@ pub(crate) struct PredictIntonationInput { pub(crate) speaker_id: Array1, } -impl InferenceInputSignature for PredictIntonationInput { - type Signature = PredictIntonation; - - fn make_run_context(self, sess: &mut R::Session) -> R::RunContext<'_> { - R::RunContext::from(sess) - .with_input(self.length) - .with_input(self.vowel_phoneme) - .with_input(self.consonant_phoneme) - .with_input(self.start_accent) - .with_input(self.end_accent) - .with_input(self.start_accent_phrase) - .with_input(self.end_accent_phrase) - .with_input(self.speaker_id) - } -} - +#[derive(TryFromVecAnyTensor)] pub(crate) struct PredictIntonationOutput { pub(crate) f0_list: Array1, } -impl TryFrom> for PredictIntonationOutput { - type Error = anyhow::Error; - - fn try_from(tensors: Vec) -> Result { - ensure!( - tensors.len() == 1, - "expected 1 tensor(s), got {}", - tensors.len(), - ); - - let mut tensors = tensors.into_iter(); - let this = Self { - f0_list: tensors - .next() - .expect("the length should have been checked") - .try_into()?, - }; - Ok(this) - } -} - pub(crate) enum Decode {} impl InferenceSignature for Decode { @@ -139,44 +74,15 @@ impl InferenceSignature for Decode { const INFERENCE: InferencelKindImpl = InferencelKindImpl::Decode; } +#[derive(InferenceInputSignature)] +#[input_signature(Signature = Decode)] pub(crate) struct DecodeInput { pub(crate) f0: Array2, pub(crate) phoneme: Array2, pub(crate) speaker_id: Array1, } -impl InferenceInputSignature for DecodeInput { - type Signature = Decode; - - fn make_run_context(self, sess: &mut R::Session) -> R::RunContext<'_> { - R::RunContext::from(sess) - .with_input(self.f0) - .with_input(self.phoneme) - .with_input(self.speaker_id) - } -} - 
+#[derive(TryFromVecAnyTensor)] pub(crate) struct DecodeOutput { pub(crate) wave: Array1, } - -impl TryFrom> for DecodeOutput { - type Error = anyhow::Error; - - fn try_from(tensors: Vec) -> Result { - ensure!( - tensors.len() == 1, - "expected 1 tensor(s), got {}", - tensors.len(), - ); - - let mut tensors = tensors.into_iter(); - let this = Self { - wave: tensors - .next() - .expect("the length should have been checked") - .try_into()?, - }; - Ok(this) - } -} diff --git a/crates/voicevox_core_macros/Cargo.toml b/crates/voicevox_core_macros/Cargo.toml new file mode 100644 index 000000000..242188b3b --- /dev/null +++ b/crates/voicevox_core_macros/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "voicevox_core_macros" +version.workspace = true +edition.workspace = true +publish.workspace = true + +[lib] +name = "macros" +proc-macro = true + +[dependencies] +proc-macro2 = "1.0.69" +quote = "1.0.33" +syn = "2.0.38" diff --git a/crates/voicevox_core_macros/src/lib.rs b/crates/voicevox_core_macros/src/lib.rs new file mode 100644 index 000000000..d35994bcd --- /dev/null +++ b/crates/voicevox_core_macros/src/lib.rs @@ -0,0 +1,150 @@ +#![warn(rust_2018_idioms)] + +use quote::quote; +use syn::{ + parse::{Parse, ParseStream}, + parse_macro_input, + spanned::Spanned as _, + Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, Token, +}; + +#[proc_macro_derive(InferenceInputSignature, attributes(input_signature))] +pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + return derive_inference_input_signature(&parse_macro_input!(input)) + .unwrap_or_else(|e| e.to_compile_error()) + .into(); + + fn derive_inference_input_signature( + input: &DeriveInput, + ) -> syn::Result { + let DeriveInput { + attrs, + ident, + generics, + data, + .. 
+ } = input; + + let AssocTypeSignature(signature) = attrs + .iter() + .find(|a| a.path().is_ident("input_signature")) + .ok_or_else(|| { + syn::Error::new( + proc_macro2::Span::call_site(), + "missing `#[input_signature(…)]`", + ) + })? + .parse_args()?; + + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let field_names = struct_field_names(data)?; + + Ok(quote! { + impl #impl_generics crate::infer::InferenceInputSignature for #ident #ty_generics + #where_clause + { + type Signature = #signature; + + fn make_run_context( + self, + sess: &mut R::Session, + ) -> R::RunContext<'_> { + let mut ctx = as ::std::convert::From<_>>::from(sess); + #( + R::push_input(self.#field_names, &mut ctx); + )* + ctx + } + } + }) + } + + struct AssocTypeSignature(syn::Ident); + + impl Parse for AssocTypeSignature { + fn parse(input: ParseStream<'_>) -> syn::Result { + let key = input.parse::()?; + if key != "Signature" { + return Err(syn::Error::new(key.span(), "expected `Signature`")); + } + input.parse::()?; + let value = input.parse::()?; + Ok(Self(value)) + } + } +} + +#[proc_macro_derive(TryFromVecAnyTensor)] +pub fn derive_try_from_vec_any_tensor(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + return derive_try_from_vec_any_tensor(&parse_macro_input!(input)) + .unwrap_or_else(|e| e.to_compile_error()) + .into(); + + fn derive_try_from_vec_any_tensor( + input: &DeriveInput, + ) -> syn::Result { + let DeriveInput { + ident, + generics, + data, + .. + } = input; + + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let field_names = struct_field_names(data)?; + let num_fields = field_names.len(); + + Ok(quote! 
{ + impl #impl_generics + ::std::convert::TryFrom<::std::vec::Vec> + for #ident #ty_generics + #where_clause + { + type Error = ::anyhow::Error; + + fn try_from(tensors: ::std::vec::Vec) -> ::std::result::Result { + ::anyhow::ensure!( + tensors.len() == #num_fields, + "expected {} tensor(s), got {}", + #num_fields, + tensors.len(), + ); + + let tensors = &mut ::std::iter::IntoIterator::into_iter(tensors); + ::std::result::Result::Ok(Self { + #( + #field_names: ::std::convert::TryInto::try_into( + ::std::iter::Iterator::next(tensors) + .expect("the length should have been checked"), + )?, + )* + }) + } + } + }) + } +} + +fn struct_field_names(data: &Data) -> syn::Result> { + let fields = match data { + Data::Struct(DataStruct { + fields: Fields::Named(fields), + .. + }) => fields, + Data::Struct(DataStruct { fields, .. }) => { + return Err(syn::Error::new(fields.span(), "expect named fields")); + } + Data::Enum(DataEnum { enum_token, .. }) => { + return Err(syn::Error::new(enum_token.span(), "expected an enum")); + } + Data::Union(DataUnion { union_token, .. }) => { + return Err(syn::Error::new(union_token.span(), "expected an enum")); + } + }; + + Ok(fields + .named + .iter() + .map(|Field { ident, .. 
}| ident.as_ref().expect("should be named")) + .collect()) +} From 590ce48f0de3b2052518a5204f89ee53774d1b10 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sat, 11 Nov 2023 17:54:43 +0900 Subject: [PATCH 24/47] =?UTF-8?q?`AnyTensor`=20=E2=86=92=20`OutputTensor`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer.rs | 18 +++++++++--------- .../src/infer/runtimes/onnxruntime.rs | 6 +++--- crates/voicevox_core/src/infer/signatures.rs | 10 +++++----- crates/voicevox_core_macros/src/lib.rs | 8 +++++--- 4 files changed, 22 insertions(+), 20 deletions(-) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 0eaea5b63..69c5ab9fd 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -28,7 +28,7 @@ pub(crate) trait InferenceRuntime: 'static { ctx: &mut Self::RunContext<'_>, ); - fn run(ctx: Self::RunContext<'_>) -> anyhow::Result>; + fn run(ctx: Self::RunContext<'_>) -> anyhow::Result>; } pub(crate) trait RunContext<'a>: @@ -52,7 +52,7 @@ pub(crate) trait InferenceGroup { pub(crate) trait InferenceSignature: Sized + Send + 'static { type Group: InferenceGroup; type Input: InferenceInputSignature; - type Output: TryFrom, Error = anyhow::Error> + Send; + type Output: TryFrom, Error = anyhow::Error> + Send; const INFERENCE: ::Kind; } @@ -67,26 +67,26 @@ impl InputScalar for i64 {} impl InputScalar for f32 {} pub(crate) trait OutputScalar: Sized { - fn extract_dyn_dim(tensor: AnyTensor) -> std::result::Result, ExtractError>; + fn extract(tensor: OutputTensor) -> std::result::Result, ExtractError>; } impl OutputScalar for f32 { - fn extract_dyn_dim(tensor: AnyTensor) -> std::result::Result, ExtractError> { + fn extract(tensor: OutputTensor) -> std::result::Result, ExtractError> { match tensor { - AnyTensor::Float32(tensor) => Ok(tensor), + OutputTensor::Float32(tensor) => Ok(tensor), } } } -pub(crate) enum AnyTensor { +pub(crate) 
enum OutputTensor { Float32(ArrayD), } -impl TryFrom for Array { +impl TryFrom for Array { type Error = ExtractError; - fn try_from(tensor: AnyTensor) -> Result { - let this = A::extract_dyn_dim(tensor)?.into_dimensionality()?; + fn try_from(tensor: OutputTensor) -> Result { + let this = A::extract(tensor)?.into_dimensionality()?; Ok(this) } } diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs index 26bc93655..45de3e658 100644 --- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -11,7 +11,7 @@ use crate::{ devices::SupportedDevices, error::ErrorRepr, infer::{ - AnyTensor, DecryptModelError, InferenceRuntime, InferenceSessionOptions, InputScalar, + DecryptModelError, InferenceRuntime, InferenceSessionOptions, InputScalar, OutputTensor, RunContext, }, }; @@ -101,7 +101,7 @@ impl InferenceRuntime for Onnxruntime { fn run( OnnxruntimeRunContext { sess, mut inputs }: OnnxruntimeRunContext<'_>, - ) -> anyhow::Result> { + ) -> anyhow::Result> { // FIXME: 現状では`f32`のみ対応。実行時にsessionからdatatypeが取れるので、別の型の対応も // おそらく可能ではあるが、それが必要になるよりもortクレートへの引越しが先になると思われる // のでこのままにする。 @@ -120,7 +120,7 @@ impl InferenceRuntime for Onnxruntime { Ok(outputs .iter() - .map(|o| AnyTensor::Float32((*o).clone().into_owned())) + .map(|o| OutputTensor::Float32((*o).clone().into_owned())) .collect()) } } diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/signatures.rs index 469109506..243f36649 100644 --- a/crates/voicevox_core/src/infer/signatures.rs +++ b/crates/voicevox_core/src/infer/signatures.rs @@ -1,8 +1,8 @@ use enum_map::Enum; -use macros::{InferenceInputSignature, TryFromVecAnyTensor}; +use macros::{InferenceInputSignature, TryFromVecOutputTensor}; use ndarray::{Array0, Array1, Array2}; -use super::{AnyTensor, InferenceGroup, InferenceSignature}; +use super::{InferenceGroup, InferenceSignature, 
OutputTensor}; pub(crate) enum InferenceGroupImpl {} @@ -33,7 +33,7 @@ pub(crate) struct PredictDurationInput { pub(crate) speaker_id: Array1, } -#[derive(TryFromVecAnyTensor)] +#[derive(TryFromVecOutputTensor)] pub(crate) struct PredictDurationOutput { pub(crate) phoneme_length: Array1, } @@ -60,7 +60,7 @@ pub(crate) struct PredictIntonationInput { pub(crate) speaker_id: Array1, } -#[derive(TryFromVecAnyTensor)] +#[derive(TryFromVecOutputTensor)] pub(crate) struct PredictIntonationOutput { pub(crate) f0_list: Array1, } @@ -82,7 +82,7 @@ pub(crate) struct DecodeInput { pub(crate) speaker_id: Array1, } -#[derive(TryFromVecAnyTensor)] +#[derive(TryFromVecOutputTensor)] pub(crate) struct DecodeOutput { pub(crate) wave: Array1, } diff --git a/crates/voicevox_core_macros/src/lib.rs b/crates/voicevox_core_macros/src/lib.rs index d35994bcd..192ca83ba 100644 --- a/crates/voicevox_core_macros/src/lib.rs +++ b/crates/voicevox_core_macros/src/lib.rs @@ -74,7 +74,7 @@ pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_ } } -#[proc_macro_derive(TryFromVecAnyTensor)] +#[proc_macro_derive(TryFromVecOutputTensor)] pub fn derive_try_from_vec_any_tensor(input: proc_macro::TokenStream) -> proc_macro::TokenStream { return derive_try_from_vec_any_tensor(&parse_macro_input!(input)) .unwrap_or_else(|e| e.to_compile_error()) @@ -96,13 +96,15 @@ pub fn derive_try_from_vec_any_tensor(input: proc_macro::TokenStream) -> proc_ma Ok(quote! 
{ impl #impl_generics - ::std::convert::TryFrom<::std::vec::Vec> + ::std::convert::TryFrom<::std::vec::Vec> for #ident #ty_generics #where_clause { type Error = ::anyhow::Error; - fn try_from(tensors: ::std::vec::Vec) -> ::std::result::Result { + fn try_from( + tensors: ::std::vec::Vec, + ) -> ::std::result::Result { ::anyhow::ensure!( tensors.len() == #num_fields, "expected {} tensor(s), got {}", From c4d5ebe3157126fb03253d404aa3177f0fbf195c Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sat, 11 Nov 2023 17:58:55 +0900 Subject: [PATCH 25/47] =?UTF-8?q?`INFERENCE`=20=E2=86=92=20`KIND`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer.rs | 4 ++-- crates/voicevox_core/src/infer/signatures.rs | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 69c5ab9fd..7124e9a0a 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -53,7 +53,7 @@ pub(crate) trait InferenceSignature: Sized + Send + 'static { type Group: InferenceGroup; type Input: InferenceInputSignature; type Output: TryFrom, Error = anyhow::Error> + Send; - const INFERENCE: ::Kind; + const KIND: ::Kind; } pub(crate) trait InferenceInputSignature: Send + 'static { @@ -121,7 +121,7 @@ impl InferenceSessionSet { I::Signature: InferenceSignature, { InferenceSessionCell { - inner: self.0[I::Signature::INFERENCE].clone(), + inner: self.0[I::Signature::KIND].clone(), marker: PhantomData, } } diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/signatures.rs index 243f36649..b7efb6244 100644 --- a/crates/voicevox_core/src/infer/signatures.rs +++ b/crates/voicevox_core/src/infer/signatures.rs @@ -23,7 +23,7 @@ impl InferenceSignature for PredictDuration { type Group = InferenceGroupImpl; type Input = PredictDurationInput; type Output = PredictDurationOutput; - const INFERENCE: 
InferencelKindImpl = InferencelKindImpl::PredictDuration; + const KIND: InferencelKindImpl = InferencelKindImpl::PredictDuration; } #[derive(InferenceInputSignature)] @@ -44,7 +44,7 @@ impl InferenceSignature for PredictIntonation { type Group = InferenceGroupImpl; type Input = PredictIntonationInput; type Output = PredictIntonationOutput; - const INFERENCE: InferencelKindImpl = InferencelKindImpl::PredictIntonation; + const KIND: InferencelKindImpl = InferencelKindImpl::PredictIntonation; } #[derive(InferenceInputSignature)] @@ -71,7 +71,7 @@ impl InferenceSignature for Decode { type Group = InferenceGroupImpl; type Input = DecodeInput; type Output = DecodeOutput; - const INFERENCE: InferencelKindImpl = InferencelKindImpl::Decode; + const KIND: InferencelKindImpl = InferencelKindImpl::Decode; } #[derive(InferenceInputSignature)] From 1b1b7bfc32d7b7d21cc98cbf219046fc257d75bd Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sat, 11 Nov 2023 20:46:31 +0900 Subject: [PATCH 26/47] =?UTF-8?q?`status`=E3=82=92`infer`=E4=B8=8B?= =?UTF-8?q?=E3=81=AB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer.rs | 63 +----- .../voicevox_core/src/{ => infer}/status.rs | 195 +++++++++++++----- crates/voicevox_core/src/inference_core.rs | 30 ++- crates/voicevox_core/src/lib.rs | 1 - 4 files changed, 166 insertions(+), 123 deletions(-) rename crates/voicevox_core/src/{ => infer}/status.rs (54%) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 7124e9a0a..76706b63b 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -1,16 +1,17 @@ mod model_file; pub(crate) mod runtimes; pub(crate) mod signatures; +pub(crate) mod status; -use std::{collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc}; +use std::fmt::Debug; use derive_new::new; use easy_ext::ext; -use enum_map::{Enum, EnumMap}; +use enum_map::Enum; use ndarray::{Array, 
ArrayD, Dimension, ShapeError}; use thiserror::Error; -use crate::{ErrorRepr, SupportedDevices}; +use crate::SupportedDevices; pub(crate) trait InferenceRuntime: 'static { type Session: Sized + Send + 'static; @@ -91,61 +92,7 @@ impl TryFrom for Array { } } -pub(crate) struct InferenceSessionSet( - EnumMap>>, -); - -impl InferenceSessionSet { - pub(crate) fn new( - model_bytes: &EnumMap>, - mut options: impl FnMut(G::Kind) -> InferenceSessionOptions, - ) -> anyhow::Result { - let mut sessions = model_bytes - .iter() - .map(|(k, m)| { - let sess = R::new_session(|| model_file::decrypt(m), options(k))?; - Ok((k.into_usize(), std::sync::Mutex::new(sess).into())) - }) - .collect::>>()?; - - Ok(Self(EnumMap::::from_fn(|k| { - sessions.remove(&k.into_usize()).expect("should exist") - }))) - } -} - -impl InferenceSessionSet { - pub(crate) fn get(&self) -> InferenceSessionCell - where - I: InferenceInputSignature, - I::Signature: InferenceSignature, - { - InferenceSessionCell { - inner: self.0[I::Signature::KIND].clone(), - marker: PhantomData, - } - } -} - -pub(crate) struct InferenceSessionCell { - inner: Arc>, - marker: PhantomData, -} - -impl InferenceSessionCell { - pub(crate) fn run( - self, - input: I, - ) -> crate::Result<::Output> { - let inner = &mut self.inner.lock().unwrap(); - let ctx = input.make_run_context::(inner); - R::run(ctx) - .and_then(TryInto::try_into) - .map_err(|e| ErrorRepr::InferenceFailed(e).into()) - } -} - -#[derive(new, Clone, Copy)] +#[derive(new, Clone, Copy, PartialEq, Debug)] pub(crate) struct InferenceSessionOptions { pub(crate) cpu_num_threads: u16, pub(crate) use_gpu: bool, diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/infer/status.rs similarity index 54% rename from crates/voicevox_core/src/status.rs rename to crates/voicevox_core/src/infer/status.rs index 51cabf20d..4938f89a6 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/infer/status.rs @@ -1,48 +1,57 @@ -use super::*; 
-use crate::infer::{ - signatures::{InferenceGroupImpl, InferencelKindImpl}, - InferenceInputSignature, InferenceRuntime, InferenceSessionCell, InferenceSessionOptions, - InferenceSessionSet, InferenceSignature, +use std::{ + collections::{BTreeMap, HashMap}, + marker::PhantomData, + sync::Arc, }; + use educe::Educe; +use enum_map::{Enum as _, EnumMap}; use itertools::iproduct; -use std::collections::BTreeMap; +use crate::{ + error::{ErrorRepr, LoadModelError, LoadModelErrorKind, LoadModelResult}, + manifest::ModelInnerId, + metas::{SpeakerMeta, StyleId, StyleMeta, VoiceModelMeta}, + voice_model::{VoiceModel, VoiceModelId}, + Result, +}; + +use super::{ + model_file, InferenceGroup, InferenceInputSignature, InferenceRuntime, InferenceSessionOptions, + InferenceSignature, +}; -pub(crate) struct Status { - loaded_models: std::sync::Mutex>, - light_session_options: InferenceSessionOptions, // 軽いモデルはこちらを使う - heavy_session_options: InferenceSessionOptions, // 重いモデルはこちらを使う +pub(crate) struct Status { + loaded_models: std::sync::Mutex>, + session_options: EnumMap, } -impl Status { - pub fn new(use_gpu: bool, cpu_num_threads: u16) -> Self { +impl Status { + pub fn new(session_options: EnumMap) -> Self { Self { loaded_models: Default::default(), - light_session_options: InferenceSessionOptions::new(cpu_num_threads, false), - heavy_session_options: InferenceSessionOptions::new(cpu_num_threads, use_gpu), + session_options, } } - pub async fn load_model(&self, model: &VoiceModel) -> Result<()> { + pub async fn load_model( + &self, + model: &VoiceModel, + model_bytes: &EnumMap>, + ) -> Result<()> { self.loaded_models .lock() .unwrap() .ensure_acceptable(model)?; - let model_bytes = &model.read_inference_models().await?; - - let session_set = InferenceSessionSet::new(model_bytes, |kind| match kind { - InferencelKindImpl::PredictDuration | InferencelKindImpl::PredictIntonation => { - self.light_session_options - } - InferencelKindImpl::Decode => self.heavy_session_options, - }) - 
.map_err(|source| LoadModelError { - path: model.path().clone(), - context: LoadModelErrorKind::InvalidModelData, - source: Some(source), - })?; + let session_set = + SessionSet::new(model_bytes, &self.session_options).map_err(|source| { + LoadModelError { + path: model.path().clone(), + context: LoadModelErrorKind::InvalidModelData, + source: Some(source), + } + })?; self.loaded_models .lock() @@ -88,7 +97,7 @@ impl Status { ) -> Result<::Output> where I: InferenceInputSignature, - I::Signature: InferenceSignature, + I::Signature: InferenceSignature, { let sess = self.loaded_models.lock().unwrap().get(model_id); @@ -102,16 +111,18 @@ impl Status { /// /// この構造体のメソッドは、すべて一瞬で完了すべきである。 #[derive(Educe)] -#[educe(Default(bound = "R: InferenceRuntime"))] -struct LoadedModels(BTreeMap>); +#[educe(Default(bound = "R: InferenceRuntime, G: InferenceGroup"))] +struct LoadedModels( + BTreeMap>, +); -struct LoadedModel { +struct LoadedModel { model_inner_ids: BTreeMap, metas: VoiceModelMeta, - session_set: InferenceSessionSet, + session_set: SessionSet, } -impl LoadedModels { +impl LoadedModels { fn metas(&self) -> VoiceModelMeta { self.0 .values() @@ -147,10 +158,10 @@ impl LoadedModels { /// # Panics /// /// `self`が`model_id`を含んでいないとき、パニックする。 - fn get(&self, model_id: &VoiceModelId) -> InferenceSessionCell + fn get(&self, model_id: &VoiceModelId) -> SessionCell where I: InferenceInputSignature, - I::Signature: InferenceSignature, + I::Signature: InferenceSignature, { self.0[model_id].session_set.get() } @@ -193,11 +204,7 @@ impl LoadedModels { Ok(()) } - fn insert( - &mut self, - model: &VoiceModel, - session_set: InferenceSessionSet, - ) -> Result<()> { + fn insert(&mut self, model: &VoiceModel, session_set: SessionSet) -> Result<()> { self.ensure_acceptable(model)?; let prev = self.0.insert( @@ -230,13 +237,71 @@ impl LoadedModels { } } +struct SessionSet( + EnumMap>>, +); + +impl SessionSet { + fn new( + model_bytes: &EnumMap>, + options: &EnumMap, + ) -> anyhow::Result { 
+ let mut sessions = model_bytes + .iter() + .map(|(k, m)| { + let sess = R::new_session(|| model_file::decrypt(m), options[k])?; + Ok((k.into_usize(), std::sync::Mutex::new(sess).into())) + }) + .collect::>>()?; + + Ok(Self(EnumMap::::from_fn(|k| { + sessions.remove(&k.into_usize()).expect("should exist") + }))) + } +} + +impl SessionSet { + fn get(&self) -> SessionCell + where + I: InferenceInputSignature, + I::Signature: InferenceSignature, + { + SessionCell { + inner: self.0[I::Signature::KIND].clone(), + marker: PhantomData, + } + } +} + +struct SessionCell { + inner: Arc>, + marker: PhantomData, +} + +impl SessionCell { + fn run(self, input: I) -> crate::Result<::Output> { + let inner = &mut self.inner.lock().unwrap(); + let ctx = input.make_run_context::(inner); + R::run(ctx) + .and_then(TryInto::try_into) + .map_err(|e| ErrorRepr::InferenceFailed(e).into()) + } +} + #[cfg(test)] mod tests { - - use super::*; - use crate::macros::tests::assert_debug_fmt_eq; - use crate::synthesizer::InferenceRuntimeImpl; + use enum_map::enum_map; use pretty_assertions::assert_eq; + use rstest::rstest; + + use crate::{ + infer::signatures::{InferenceGroupImpl, InferencelKindImpl}, + macros::tests::assert_debug_fmt_eq, + synthesizer::InferenceRuntimeImpl, + test_util::open_default_vvm_file, + }; + + use super::{super::InferenceSessionOptions, Status}; #[rstest] #[case(true, 0)] @@ -247,25 +312,40 @@ mod tests { #[case(false, 8)] #[case(false, 0)] fn status_new_works(#[case] use_gpu: bool, #[case] cpu_num_threads: u16) { - let status = Status::::new(use_gpu, cpu_num_threads); - assert_eq!(false, status.light_session_options.use_gpu); - assert_eq!(use_gpu, status.heavy_session_options.use_gpu); + let light_session_options = InferenceSessionOptions::new(cpu_num_threads, false); + let heavy_session_options = InferenceSessionOptions::new(cpu_num_threads, use_gpu); + let session_options = enum_map! 
{ + InferencelKindImpl::PredictDuration + | InferencelKindImpl::PredictIntonation => light_session_options, + InferencelKindImpl::Decode => heavy_session_options, + }; + let status = Status::::new(session_options); + assert_eq!( - cpu_num_threads, - status.light_session_options.cpu_num_threads + light_session_options, + status.session_options[InferencelKindImpl::PredictDuration], ); assert_eq!( - cpu_num_threads, - status.heavy_session_options.cpu_num_threads + light_session_options, + status.session_options[InferencelKindImpl::PredictIntonation], ); + assert_eq!( + heavy_session_options, + status.session_options[InferencelKindImpl::Decode], + ); + assert!(status.loaded_models.lock().unwrap().0.is_empty()); } #[rstest] #[tokio::test] async fn status_load_model_works() { - let status = Status::::new(false, 0); - let result = status.load_model(&open_default_vvm_file().await).await; + let status = Status::::new( + enum_map!(_ => InferenceSessionOptions::new(0, false)), + ); + let model = &open_default_vvm_file().await; + let model_bytes = &model.read_inference_models().await.unwrap(); + let result = status.load_model(model, model_bytes).await; assert_debug_fmt_eq!(Ok(()), result); assert_eq!(1, status.loaded_models.lock().unwrap().0.len()); } @@ -273,13 +353,16 @@ mod tests { #[rstest] #[tokio::test] async fn status_is_model_loaded_works() { - let status = Status::::new(false, 0); + let status = Status::::new( + enum_map!(_ => InferenceSessionOptions::new(0, false)), + ); let vvm = open_default_vvm_file().await; + let model_bytes = &vvm.read_inference_models().await.unwrap(); assert!( !status.is_loaded_model(vvm.id()), "model should not be loaded" ); - let result = status.load_model(&vvm).await; + let result = status.load_model(&vvm, model_bytes).await; assert_debug_fmt_eq!(Ok(()), result); assert!(status.is_loaded_model(vvm.id()), "model should be loaded"); } diff --git a/crates/voicevox_core/src/inference_core.rs b/crates/voicevox_core/src/inference_core.rs index 
30dc37995..5987a0df4 100644 --- a/crates/voicevox_core/src/inference_core.rs +++ b/crates/voicevox_core/src/inference_core.rs @@ -1,23 +1,36 @@ -use self::status::*; -use super::*; +use enum_map::enum_map; + use crate::infer::{ signatures::{ - DecodeInput, DecodeOutput, PredictDurationInput, PredictDurationOutput, - PredictIntonationInput, PredictIntonationOutput, + DecodeInput, DecodeOutput, InferenceGroupImpl, InferencelKindImpl, PredictDurationInput, + PredictDurationOutput, PredictIntonationInput, PredictIntonationOutput, }, - InferenceRuntime, + status::Status, + InferenceRuntime, InferenceSessionOptions, }; +use super::*; + const PHONEME_LENGTH_MINIMAL: f32 = 0.01; pub(crate) struct InferenceCore { - status: Status, + status: Status, } impl InferenceCore { pub(crate) fn new(use_gpu: bool, cpu_num_threads: u16) -> Result { if !use_gpu || Self::can_support_gpu_feature()? { - let status = Status::new(use_gpu, cpu_num_threads); + // 軽いモデルはこちらを使う + let light_session_options = InferenceSessionOptions::new(cpu_num_threads, false); + + // 重いモデルはこちらを使う + let heavy_session_options = InferenceSessionOptions::new(cpu_num_threads, use_gpu); + + let status = Status::new(enum_map! 
{ + InferencelKindImpl::PredictDuration + | InferencelKindImpl::PredictIntonation => light_session_options, + InferencelKindImpl::Decode => heavy_session_options, + }); Ok(Self { status }) } else { Err(ErrorRepr::GpuSupport.into()) @@ -37,7 +50,8 @@ impl InferenceCore { } pub async fn load_model(&self, model: &VoiceModel) -> Result<()> { - self.status.load_model(model).await + let model_bytes = &model.read_inference_models().await?; + self.status.load_model(model, model_bytes).await } pub fn unload_model(&self, voice_model_id: &VoiceModelId) -> Result<()> { diff --git a/crates/voicevox_core/src/lib.rs b/crates/voicevox_core/src/lib.rs index 407f0b8f4..dc54551a7 100644 --- a/crates/voicevox_core/src/lib.rs +++ b/crates/voicevox_core/src/lib.rs @@ -13,7 +13,6 @@ mod manifest; mod metas; mod numerics; mod result; -mod status; mod synthesizer; mod user_dict; mod version; From c39f48cb5c560b6f4dde212767ba78beba941606 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sat, 11 Nov 2023 20:49:21 +0900 Subject: [PATCH 27/47] =?UTF-8?q?`trait=20RunContext`=E3=82=92=E5=89=8A?= =?UTF-8?q?=E9=99=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer.rs | 17 +---------------- .../src/infer/runtimes/onnxruntime.rs | 5 ----- 2 files changed, 1 insertion(+), 21 deletions(-) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 76706b63b..260235543 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -6,7 +6,6 @@ pub(crate) mod status; use std::fmt::Debug; use derive_new::new; -use easy_ext::ext; use enum_map::Enum; use ndarray::{Array, ArrayD, Dimension, ShapeError}; use thiserror::Error; @@ -15,7 +14,7 @@ use crate::SupportedDevices; pub(crate) trait InferenceRuntime: 'static { type Session: Sized + Send + 'static; - type RunContext<'a>: RunContext<'a, Runtime = Self>; + type RunContext<'a>: From<&'a mut Self::Session>; fn 
supported_devices() -> crate::Result; @@ -32,20 +31,6 @@ pub(crate) trait InferenceRuntime: 'static { fn run(ctx: Self::RunContext<'_>) -> anyhow::Result>; } -pub(crate) trait RunContext<'a>: - From<&'a mut ::Session> -{ - type Runtime: InferenceRuntime = Self>; -} - -#[ext(RunContextExt)] -impl<'a, T: RunContext<'a>> T { - fn with_input(mut self, tensor: Array) -> Self { - T::Runtime::push_input(tensor, &mut self); - self - } -} - pub(crate) trait InferenceGroup { type Kind: Copy + Enum; } diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs index 45de3e658..6c3901d04 100644 --- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -12,7 +12,6 @@ use crate::{ error::ErrorRepr, infer::{ DecryptModelError, InferenceRuntime, InferenceSessionOptions, InputScalar, OutputTensor, - RunContext, }, }; @@ -141,10 +140,6 @@ impl<'sess> From<&'sess mut AssertSend>> } } -impl<'sess> RunContext<'sess> for OnnxruntimeRunContext<'sess> { - type Runtime = Onnxruntime; -} - // FIXME: 以下のことをちゃんと確認した後、onnxruntime-rs側で`Session`が`Send`であると宣言する。 // https://github.com/VOICEVOX/voicevox_core/issues/307#issuecomment-1276184614 mod assert_send { From c316209419cc81f96d16739ac10aaba1c6a907d1 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sat, 11 Nov 2023 22:49:57 +0900 Subject: [PATCH 28/47] =?UTF-8?q?"kind"=E3=82=92=E7=9B=B4=E6=8E=A5"group"?= =?UTF-8?q?=E3=81=A8=E5=91=BC=E3=81=B6=E3=81=93=E3=81=A8=E3=81=AB=E3=81=99?= =?UTF-8?q?=E3=82=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer.rs | 6 +-- crates/voicevox_core/src/infer/signatures.rs | 26 +++++-------- crates/voicevox_core/src/infer/status.rs | 40 ++++++++++---------- crates/voicevox_core/src/inference_core.rs | 12 +++--- crates/voicevox_core/src/voice_model.rs | 4 +- crates/voicevox_core_macros/src/lib.rs | 6 +++ 6 
files changed, 45 insertions(+), 49 deletions(-) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 260235543..2a0ac3318 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -31,15 +31,13 @@ pub(crate) trait InferenceRuntime: 'static { fn run(ctx: Self::RunContext<'_>) -> anyhow::Result>; } -pub(crate) trait InferenceGroup { - type Kind: Copy + Enum; -} +pub(crate) trait InferenceGroup: Copy + Enum {} pub(crate) trait InferenceSignature: Sized + Send + 'static { type Group: InferenceGroup; type Input: InferenceInputSignature; type Output: TryFrom, Error = anyhow::Error> + Send; - const KIND: ::Kind; + const KIND: Self::Group; } pub(crate) trait InferenceInputSignature: Send + 'static { diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/signatures.rs index b7efb6244..bce6f62da 100644 --- a/crates/voicevox_core/src/infer/signatures.rs +++ b/crates/voicevox_core/src/infer/signatures.rs @@ -1,17 +1,11 @@ use enum_map::Enum; -use macros::{InferenceInputSignature, TryFromVecOutputTensor}; +use macros::{InferenceGroup, InferenceInputSignature, TryFromVecOutputTensor}; use ndarray::{Array0, Array1, Array2}; -use super::{InferenceGroup, InferenceSignature, OutputTensor}; +use super::{InferenceSignature, OutputTensor}; -pub(crate) enum InferenceGroupImpl {} - -impl InferenceGroup for InferenceGroupImpl { - type Kind = InferencelKindImpl; -} - -#[derive(Clone, Copy, Enum)] -pub(crate) enum InferencelKindImpl { +#[derive(Clone, Copy, Enum, InferenceGroup)] +pub(crate) enum InferenceKind { PredictDuration, PredictIntonation, Decode, @@ -20,10 +14,10 @@ pub(crate) enum InferencelKindImpl { pub(crate) enum PredictDuration {} impl InferenceSignature for PredictDuration { - type Group = InferenceGroupImpl; + type Group = InferenceKind; type Input = PredictDurationInput; type Output = PredictDurationOutput; - const KIND: InferencelKindImpl = 
InferencelKindImpl::PredictDuration; + const KIND: InferenceKind = InferenceKind::PredictDuration; } #[derive(InferenceInputSignature)] @@ -41,10 +35,10 @@ pub(crate) struct PredictDurationOutput { pub(crate) enum PredictIntonation {} impl InferenceSignature for PredictIntonation { - type Group = InferenceGroupImpl; + type Group = InferenceKind; type Input = PredictIntonationInput; type Output = PredictIntonationOutput; - const KIND: InferencelKindImpl = InferencelKindImpl::PredictIntonation; + const KIND: InferenceKind = InferenceKind::PredictIntonation; } #[derive(InferenceInputSignature)] @@ -68,10 +62,10 @@ pub(crate) struct PredictIntonationOutput { pub(crate) enum Decode {} impl InferenceSignature for Decode { - type Group = InferenceGroupImpl; + type Group = InferenceKind; type Input = DecodeInput; type Output = DecodeOutput; - const KIND: InferencelKindImpl = InferencelKindImpl::Decode; + const KIND: InferenceKind = InferenceKind::Decode; } #[derive(InferenceInputSignature)] diff --git a/crates/voicevox_core/src/infer/status.rs b/crates/voicevox_core/src/infer/status.rs index 4938f89a6..587ce21fa 100644 --- a/crates/voicevox_core/src/infer/status.rs +++ b/crates/voicevox_core/src/infer/status.rs @@ -5,7 +5,7 @@ use std::{ }; use educe::Educe; -use enum_map::{Enum as _, EnumMap}; +use enum_map::EnumMap; use itertools::iproduct; use crate::{ @@ -23,11 +23,11 @@ use super::{ pub(crate) struct Status { loaded_models: std::sync::Mutex>, - session_options: EnumMap, + session_options: EnumMap, } impl Status { - pub fn new(session_options: EnumMap) -> Self { + pub fn new(session_options: EnumMap) -> Self { Self { loaded_models: Default::default(), session_options, @@ -37,7 +37,7 @@ impl Status { pub async fn load_model( &self, model: &VoiceModel, - model_bytes: &EnumMap>, + model_bytes: &EnumMap>, ) -> Result<()> { self.loaded_models .lock() @@ -238,13 +238,13 @@ impl LoadedModels { } struct SessionSet( - EnumMap>>, + EnumMap>>, ); impl SessionSet { fn new( - 
model_bytes: &EnumMap>, - options: &EnumMap, + model_bytes: &EnumMap>, + options: &EnumMap, ) -> anyhow::Result { let mut sessions = model_bytes .iter() @@ -254,7 +254,7 @@ impl SessionSet { }) .collect::>>()?; - Ok(Self(EnumMap::::from_fn(|k| { + Ok(Self(EnumMap::::from_fn(|k| { sessions.remove(&k.into_usize()).expect("should exist") }))) } @@ -295,10 +295,8 @@ mod tests { use rstest::rstest; use crate::{ - infer::signatures::{InferenceGroupImpl, InferencelKindImpl}, - macros::tests::assert_debug_fmt_eq, - synthesizer::InferenceRuntimeImpl, - test_util::open_default_vvm_file, + infer::signatures::InferenceKind, macros::tests::assert_debug_fmt_eq, + synthesizer::InferenceRuntimeImpl, test_util::open_default_vvm_file, }; use super::{super::InferenceSessionOptions, Status}; @@ -315,23 +313,23 @@ mod tests { let light_session_options = InferenceSessionOptions::new(cpu_num_threads, false); let heavy_session_options = InferenceSessionOptions::new(cpu_num_threads, use_gpu); let session_options = enum_map! 
{ - InferencelKindImpl::PredictDuration - | InferencelKindImpl::PredictIntonation => light_session_options, - InferencelKindImpl::Decode => heavy_session_options, + InferenceKind::PredictDuration + | InferenceKind::PredictIntonation => light_session_options, + InferenceKind::Decode => heavy_session_options, }; - let status = Status::::new(session_options); + let status = Status::::new(session_options); assert_eq!( light_session_options, - status.session_options[InferencelKindImpl::PredictDuration], + status.session_options[InferenceKind::PredictDuration], ); assert_eq!( light_session_options, - status.session_options[InferencelKindImpl::PredictIntonation], + status.session_options[InferenceKind::PredictIntonation], ); assert_eq!( heavy_session_options, - status.session_options[InferencelKindImpl::Decode], + status.session_options[InferenceKind::Decode], ); assert!(status.loaded_models.lock().unwrap().0.is_empty()); @@ -340,7 +338,7 @@ mod tests { #[rstest] #[tokio::test] async fn status_load_model_works() { - let status = Status::::new( + let status = Status::::new( enum_map!(_ => InferenceSessionOptions::new(0, false)), ); let model = &open_default_vvm_file().await; @@ -353,7 +351,7 @@ mod tests { #[rstest] #[tokio::test] async fn status_is_model_loaded_works() { - let status = Status::::new( + let status = Status::::new( enum_map!(_ => InferenceSessionOptions::new(0, false)), ); let vvm = open_default_vvm_file().await; diff --git a/crates/voicevox_core/src/inference_core.rs b/crates/voicevox_core/src/inference_core.rs index 5987a0df4..5fdbdf6dd 100644 --- a/crates/voicevox_core/src/inference_core.rs +++ b/crates/voicevox_core/src/inference_core.rs @@ -2,8 +2,8 @@ use enum_map::enum_map; use crate::infer::{ signatures::{ - DecodeInput, DecodeOutput, InferenceGroupImpl, InferencelKindImpl, PredictDurationInput, - PredictDurationOutput, PredictIntonationInput, PredictIntonationOutput, + DecodeInput, DecodeOutput, InferenceKind, PredictDurationInput, 
PredictDurationOutput, + PredictIntonationInput, PredictIntonationOutput, }, status::Status, InferenceRuntime, InferenceSessionOptions, @@ -14,7 +14,7 @@ use super::*; const PHONEME_LENGTH_MINIMAL: f32 = 0.01; pub(crate) struct InferenceCore { - status: Status, + status: Status, } impl InferenceCore { @@ -27,9 +27,9 @@ impl InferenceCore { let heavy_session_options = InferenceSessionOptions::new(cpu_num_threads, use_gpu); let status = Status::new(enum_map! { - InferencelKindImpl::PredictDuration - | InferencelKindImpl::PredictIntonation => light_session_options, - InferencelKindImpl::Decode => heavy_session_options, + InferenceKind::PredictDuration + | InferenceKind::PredictIntonation => light_session_options, + InferenceKind::Decode => heavy_session_options, }); Ok(Self { status }) } else { diff --git a/crates/voicevox_core/src/voice_model.rs b/crates/voicevox_core/src/voice_model.rs index 136bc3742..5b75bcacf 100644 --- a/crates/voicevox_core/src/voice_model.rs +++ b/crates/voicevox_core/src/voice_model.rs @@ -4,7 +4,7 @@ use futures::future::join3; use serde::{de::DeserializeOwned, Deserialize}; use super::*; -use crate::infer::signatures::InferencelKindImpl; +use crate::infer::signatures::InferenceKind; use std::{ collections::{BTreeMap, HashMap}, io, @@ -40,7 +40,7 @@ pub struct VoiceModel { impl VoiceModel { pub(crate) async fn read_inference_models( &self, - ) -> LoadModelResult>> { + ) -> LoadModelResult>> { let reader = VvmEntryReader::open(&self.path).await?; let (decode_model_result, predict_duration_model_result, predict_intonation_model_result) = join3( diff --git a/crates/voicevox_core_macros/src/lib.rs b/crates/voicevox_core_macros/src/lib.rs index 192ca83ba..dcd235537 100644 --- a/crates/voicevox_core_macros/src/lib.rs +++ b/crates/voicevox_core_macros/src/lib.rs @@ -8,6 +8,12 @@ use syn::{ Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, Token, }; +#[proc_macro_derive(InferenceGroup)] +pub fn derive_inference_group(input: 
proc_macro::TokenStream) -> proc_macro::TokenStream { + let DeriveInput { ident, .. } = parse_macro_input!(input as DeriveInput); + quote!(impl crate::infer::InferenceGroup for #ident {}).into() +} + #[proc_macro_derive(InferenceInputSignature, attributes(input_signature))] pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_macro::TokenStream { return derive_inference_input_signature(&parse_macro_input!(input)) From 2274a34d1382a4c528228c5caa63c2f5cdad7b35 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Sun, 12 Nov 2023 10:58:31 +0900 Subject: [PATCH 29/47] =?UTF-8?q?=E3=82=B7=E3=82=B0=E3=83=8D=E3=83=81?= =?UTF-8?q?=E3=83=A3=E3=81=AE=E5=AE=9F=E8=A1=8C=E6=99=82=E3=83=81=E3=82=A7?= =?UTF-8?q?=E3=83=83=E3=82=AF=E6=A9=9F=E6=A7=8B=E3=82=92=E5=85=A5=E3=82=8C?= =?UTF-8?q?=E3=82=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer.rs | 94 +++++++++++++++++-- .../src/infer/runtimes/onnxruntime.rs | 84 +++++++++++++++-- crates/voicevox_core/src/infer/signatures.rs | 52 +++++++--- crates/voicevox_core/src/infer/status.rs | 36 ++++++- crates/voicevox_core/src/inference_core.rs | 14 +-- crates/voicevox_core_macros/src/lib.rs | 92 +++++++++++++++--- 6 files changed, 317 insertions(+), 55 deletions(-) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 2a0ac3318..a8577bd24 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -3,10 +3,13 @@ pub(crate) mod runtimes; pub(crate) mod signatures; pub(crate) mod status; -use std::fmt::Debug; +use std::{ + borrow::Cow, + fmt::{self, Debug, Display}, +}; use derive_new::new; -use enum_map::Enum; +use enum_map::{Enum, EnumMap}; use ndarray::{Array, ArrayD, Dimension, ShapeError}; use thiserror::Error; @@ -18,10 +21,15 @@ pub(crate) trait InferenceRuntime: 'static { fn supported_devices() -> crate::Result; + #[allow(clippy::type_complexity)] fn new_session( 
model: impl FnOnce() -> std::result::Result, DecryptModelError>, options: InferenceSessionOptions, - ) -> anyhow::Result; + ) -> anyhow::Result<( + Self::Session, + Vec>, + Vec>, + )>; fn push_input( input: Array, @@ -31,30 +39,59 @@ pub(crate) trait InferenceRuntime: 'static { fn run(ctx: Self::RunContext<'_>) -> anyhow::Result>; } -pub(crate) trait InferenceGroup: Copy + Enum {} +pub(crate) trait InferenceGroup: Copy + Enum { + const INPUT_PARAM_INFOS: EnumMap]>; + const OUTPUT_PARAM_INFOS: EnumMap]>; +} pub(crate) trait InferenceSignature: Sized + Send + 'static { type Group: InferenceGroup; type Input: InferenceInputSignature; - type Output: TryFrom, Error = anyhow::Error> + Send; + type Output: InferenceOutputSignature; const KIND: Self::Group; } pub(crate) trait InferenceInputSignature: Send + 'static { type Signature: InferenceSignature; + const PARAM_INFOS: &'static [ParamInfo]; fn make_run_context(self, sess: &mut R::Session) -> R::RunContext<'_>; } -pub(crate) trait InputScalar: sealed::InputScalar + Debug + 'static {} +pub(crate) trait InputScalar: sealed::InputScalar + Debug + 'static { + const KIND: InputScalarKind; +} -impl InputScalar for i64 {} -impl InputScalar for f32 {} +impl InputScalar for i64 { + const KIND: InputScalarKind = InputScalarKind::Int64; +} + +impl InputScalar for f32 { + const KIND: InputScalarKind = InputScalarKind::Float32; +} + +#[derive(Clone, Copy, PartialEq, derive_more::Display)] +pub(crate) enum InputScalarKind { + #[display(fmt = "int64_t")] + Int64, + + #[display(fmt = "float")] + Float32, +} + +pub(crate) trait InferenceOutputSignature: + TryFrom, Error = anyhow::Error> + Send +{ + const PARAM_INFOS: &'static [ParamInfo]; +} pub(crate) trait OutputScalar: Sized { + const KIND: OutputScalarKind; fn extract(tensor: OutputTensor) -> std::result::Result, ExtractError>; } impl OutputScalar for f32 { + const KIND: OutputScalarKind = OutputScalarKind::Float32; + fn extract(tensor: OutputTensor) -> std::result::Result, 
ExtractError> { match tensor { OutputTensor::Float32(tensor) => Ok(tensor), @@ -62,6 +99,12 @@ impl OutputScalar for f32 { } } +#[derive(Clone, Copy, PartialEq, derive_more::Display)] +pub(crate) enum OutputScalarKind { + #[display(fmt = "float")] + Float32, +} + pub(crate) enum OutputTensor { Float32(ArrayD), } @@ -75,6 +118,41 @@ impl TryFrom for Array { } } +pub(crate) struct ParamInfo { + name: Cow<'static, str>, + dt: D, + ndim: Option, +} + +impl ParamInfo { + fn accepts(&self, other: &Self) -> bool { + self.name == other.name + && self.dt == other.dt + && (self.ndim.is_none() || self.ndim == other.ndim) + } +} + +impl Display for ParamInfo { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}: {}", self.name, self.dt)?; + if let Some(ndim) = self.ndim { + f.write_str(&"[]".repeat(ndim)) + } else { + f.write_str("[]...") + } + } +} + +pub(crate) trait ArrayExt { + type Scalar; + type Dimension: Dimension; +} + +impl ArrayExt for Array { + type Scalar = A; + type Dimension = D; +} + #[derive(new, Clone, Copy, PartialEq, Debug)] pub(crate) struct InferenceSessionOptions { pub(crate) cpu_num_threads: u16, diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs index 6c3901d04..ca5b28aaa 100644 --- a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -1,18 +1,19 @@ -use std::fmt::Debug; +use std::{fmt::Debug, vec}; +use anyhow::anyhow; use ndarray::{Array, Dimension}; use once_cell::sync::Lazy; use onnxruntime::{ environment::Environment, GraphOptimizationLevel, LoggingLevel, TensorElementDataType, }; +use crate::{devices::SupportedDevices, error::ErrorRepr}; + use self::assert_send::AssertSend; -use crate::{ - devices::SupportedDevices, - error::ErrorRepr, - infer::{ - DecryptModelError, InferenceRuntime, InferenceSessionOptions, InputScalar, OutputTensor, - }, + +use super::super::{ + 
DecryptModelError, InferenceRuntime, InferenceSessionOptions, InputScalar, InputScalarKind, + OutputScalarKind, OutputTensor, ParamInfo, }; #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] @@ -47,7 +48,11 @@ impl InferenceRuntime for Onnxruntime { fn new_session( model: impl FnOnce() -> std::result::Result, DecryptModelError>, options: InferenceSessionOptions, - ) -> anyhow::Result { + ) -> anyhow::Result<( + Self::Session, + Vec>, + Vec>, + )> { let mut builder = ENVIRONMENT .new_session_builder()? .with_optimization_level(GraphOptimizationLevel::Basic)? @@ -72,8 +77,67 @@ impl InferenceRuntime for Onnxruntime { } let model = model()?; - let sess = builder.with_model_from_memory(model)?.into(); - return Ok(sess); + let sess = AssertSend::from(builder.with_model_from_memory(model)?); + + let input_param_infos = sess + .inputs + .iter() + .map(|info| { + let dt = match info.input_type { + TensorElementDataType::Float => Ok(InputScalarKind::Float32), + TensorElementDataType::Uint8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8"), + TensorElementDataType::Int8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8"), + TensorElementDataType::Uint16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16"), + TensorElementDataType::Int16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16"), + TensorElementDataType::Int32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"), + TensorElementDataType::Int64 => Ok(InputScalarKind::Int64), + TensorElementDataType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"), + TensorElementDataType::Double => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"), + TensorElementDataType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"), + TensorElementDataType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"), + } + .map_err(|actual| { + anyhow!("unsupported input datatype `{actual}` for `{}`", info.name) + })?; + + Ok(ParamInfo { + name: info.name.clone().into(), + dt, + ndim: Some(info.dimensions.len()), + }) + }) + .collect::>()?; + + let 
output_param_infos = sess + .outputs + .iter() + .map(|info| { + let dt = match info.output_type { + TensorElementDataType::Float => Ok(OutputScalarKind::Float32), + TensorElementDataType::Uint8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8"), + TensorElementDataType::Int8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8"), + TensorElementDataType::Uint16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16"), + TensorElementDataType::Int16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16"), + TensorElementDataType::Int32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"), + TensorElementDataType::Int64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64"), + TensorElementDataType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"), + TensorElementDataType::Double => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"), + TensorElementDataType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"), + TensorElementDataType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"), + } + .map_err(|actual| { + anyhow!("unsupported output datatype `{actual}` for `{}`", info.name) + })?; + + Ok(ParamInfo { + name: info.name.clone().into(), + dt, + ndim: Some(info.dimensions.len()), + }) + }) + .collect::>()?; + + return Ok((sess, input_param_infos, output_param_infos)); static ENVIRONMENT: Lazy = Lazy::new(|| { Environment::builder() diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/signatures.rs index bce6f62da..c4633658d 100644 --- a/crates/voicevox_core/src/infer/signatures.rs +++ b/crates/voicevox_core/src/infer/signatures.rs @@ -1,16 +1,40 @@ -use enum_map::Enum; -use macros::{InferenceGroup, InferenceInputSignature, TryFromVecOutputTensor}; +use enum_map::{Enum, EnumMap}; +use macros::{InferenceInputSignature, InferenceOutputSignature}; use ndarray::{Array0, Array1, Array2}; -use super::{InferenceSignature, OutputTensor}; +use super::{ + InferenceGroup, InferenceInputSignature as _, InferenceOutputSignature as _, + InferenceSignature, OutputTensor, +}; 
-#[derive(Clone, Copy, Enum, InferenceGroup)] +#[derive(Clone, Copy, Enum)] pub(crate) enum InferenceKind { PredictDuration, PredictIntonation, Decode, } +// FIXME: ここもマクロ化する +impl InferenceGroup for InferenceKind { + const INPUT_PARAM_INFOS: enum_map::EnumMap< + Self, + &'static [super::ParamInfo], + > = EnumMap::from_array([ + PredictDurationInput::PARAM_INFOS, + PredictIntonationInput::PARAM_INFOS, + DecodeInput::PARAM_INFOS, + ]); + + const OUTPUT_PARAM_INFOS: enum_map::EnumMap< + Self, + &'static [super::ParamInfo], + > = EnumMap::from_array([ + PredictDurationOutput::PARAM_INFOS, + PredictIntonationOutput::PARAM_INFOS, + DecodeOutput::PARAM_INFOS, + ]); +} + pub(crate) enum PredictDuration {} impl InferenceSignature for PredictDuration { @@ -23,11 +47,11 @@ impl InferenceSignature for PredictDuration { #[derive(InferenceInputSignature)] #[input_signature(Signature = PredictDuration)] pub(crate) struct PredictDurationInput { - pub(crate) phoneme: Array1, + pub(crate) phoneme_list: Array1, pub(crate) speaker_id: Array1, } -#[derive(TryFromVecOutputTensor)] +#[derive(InferenceOutputSignature)] pub(crate) struct PredictDurationOutput { pub(crate) phoneme_length: Array1, } @@ -45,16 +69,16 @@ impl InferenceSignature for PredictIntonation { #[input_signature(Signature = PredictIntonation)] pub(crate) struct PredictIntonationInput { pub(crate) length: Array0, - pub(crate) vowel_phoneme: Array1, - pub(crate) consonant_phoneme: Array1, - pub(crate) start_accent: Array1, - pub(crate) end_accent: Array1, - pub(crate) start_accent_phrase: Array1, - pub(crate) end_accent_phrase: Array1, + pub(crate) vowel_phoneme_list: Array1, + pub(crate) consonant_phoneme_list: Array1, + pub(crate) start_accent_list: Array1, + pub(crate) end_accent_list: Array1, + pub(crate) start_accent_phrase_list: Array1, + pub(crate) end_accent_phrase_list: Array1, pub(crate) speaker_id: Array1, } -#[derive(TryFromVecOutputTensor)] +#[derive(InferenceOutputSignature)] pub(crate) struct 
PredictIntonationOutput { pub(crate) f0_list: Array1, } @@ -76,7 +100,7 @@ pub(crate) struct DecodeInput { pub(crate) speaker_id: Array1, } -#[derive(TryFromVecOutputTensor)] +#[derive(InferenceOutputSignature)] pub(crate) struct DecodeOutput { pub(crate) wave: Array1, } diff --git a/crates/voicevox_core/src/infer/status.rs b/crates/voicevox_core/src/infer/status.rs index 587ce21fa..e1d2a8e3a 100644 --- a/crates/voicevox_core/src/infer/status.rs +++ b/crates/voicevox_core/src/infer/status.rs @@ -1,15 +1,18 @@ use std::{ collections::{BTreeMap, HashMap}, + fmt::Display, marker::PhantomData, sync::Arc, }; +use anyhow::bail; use educe::Educe; use enum_map::EnumMap; -use itertools::iproduct; +use itertools::{iproduct, Itertools as _}; use crate::{ error::{ErrorRepr, LoadModelError, LoadModelErrorKind, LoadModelResult}, + infer::ParamInfo, manifest::ModelInnerId, metas::{SpeakerMeta, StyleId, StyleMeta, VoiceModelMeta}, voice_model::{VoiceModel, VoiceModelId}, @@ -249,14 +252,39 @@ impl SessionSet { let mut sessions = model_bytes .iter() .map(|(k, m)| { - let sess = R::new_session(|| model_file::decrypt(m), options[k])?; + let expected_input_param_infos = G::INPUT_PARAM_INFOS[k]; + let expected_output_param_infos = G::OUTPUT_PARAM_INFOS[k]; + + let (sess, actual_input_param_infos, actual_output_param_infos) = + R::new_session(|| model_file::decrypt(m), options[k])?; + + check_param_infos(expected_input_param_infos, &actual_input_param_infos)?; + check_param_infos(expected_output_param_infos, &actual_output_param_infos)?; + Ok((k.into_usize(), std::sync::Mutex::new(sess).into())) }) .collect::>>()?; - Ok(Self(EnumMap::::from_fn(|k| { + return Ok(Self(EnumMap::::from_fn(|k| { sessions.remove(&k.into_usize()).expect("should exist") - }))) + }))); + + fn check_param_infos( + expected: &[ParamInfo], + actual: &[ParamInfo], + ) -> anyhow::Result<()> { + if !(expected.len() == actual.len() + && itertools::zip_eq(expected, actual) + .all(|(expected, actual)| 
expected.accepts(actual))) + { + bail!( + "expected {{{}}}, got {{{}}}", + expected.iter().join(", "), + actual.iter().join(", "), + ) + } + Ok(()) + } } } diff --git a/crates/voicevox_core/src/inference_core.rs b/crates/voicevox_core/src/inference_core.rs index 5fdbdf6dd..6de29c201 100644 --- a/crates/voicevox_core/src/inference_core.rs +++ b/crates/voicevox_core/src/inference_core.rs @@ -87,7 +87,7 @@ impl InferenceCore { .run_session( &model_id, PredictDurationInput { - phoneme: ndarray::arr1(phoneme_vector), + phoneme_list: ndarray::arr1(phoneme_vector), speaker_id: ndarray::arr1(&[model_inner_id.raw_id().into()]), }, ) @@ -127,12 +127,12 @@ impl InferenceCore { &model_id, PredictIntonationInput { length: ndarray::arr0(length as i64), - vowel_phoneme: ndarray::arr1(vowel_phoneme_vector), - consonant_phoneme: ndarray::arr1(consonant_phoneme_vector), - start_accent: ndarray::arr1(start_accent_vector), - end_accent: ndarray::arr1(end_accent_vector), - start_accent_phrase: ndarray::arr1(start_accent_phrase_vector), - end_accent_phrase: ndarray::arr1(end_accent_phrase_vector), + vowel_phoneme_list: ndarray::arr1(vowel_phoneme_vector), + consonant_phoneme_list: ndarray::arr1(consonant_phoneme_vector), + start_accent_list: ndarray::arr1(start_accent_vector), + end_accent_list: ndarray::arr1(end_accent_vector), + start_accent_phrase_list: ndarray::arr1(start_accent_phrase_vector), + end_accent_phrase_list: ndarray::arr1(end_accent_phrase_vector), speaker_id: ndarray::arr1(&[model_inner_id.raw_id().into()]), }, ) diff --git a/crates/voicevox_core_macros/src/lib.rs b/crates/voicevox_core_macros/src/lib.rs index dcd235537..a39aa2f20 100644 --- a/crates/voicevox_core_macros/src/lib.rs +++ b/crates/voicevox_core_macros/src/lib.rs @@ -5,13 +5,21 @@ use syn::{ parse::{Parse, ParseStream}, parse_macro_input, spanned::Spanned as _, - Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, Token, + Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, 
Token, Type, }; #[proc_macro_derive(InferenceGroup)] pub fn derive_inference_group(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let DeriveInput { ident, .. } = parse_macro_input!(input as DeriveInput); - quote!(impl crate::infer::InferenceGroup for #ident {}).into() + let DeriveInput { + ident, generics, .. + } = parse_macro_input!(input as DeriveInput); + + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + quote! { + impl #impl_generics crate::infer::InferenceGroup for #ident #ty_generics #where_clause {} + } + .into() } #[proc_macro_derive(InferenceInputSignature, attributes(input_signature))] @@ -43,7 +51,28 @@ pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_ .parse_args()?; let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - let field_names = struct_field_names(data)?; + + let fields = struct_fields(data)?; + + let param_infos = fields + .iter() + .map(|(name, ty)| { + let name = name.to_string(); + quote! { + crate::infer::ParamInfo { + name: ::std::borrow::Cow::Borrowed(#name), + dt: < + <#ty as crate::infer::ArrayExt>::Scalar as crate::infer::InputScalar + >::KIND, + ndim: < + <#ty as crate::infer::ArrayExt>::Dimension as ::ndarray::Dimension + >::NDIM, + }, + } + }) + .collect::(); + + let field_names = fields.iter().map(|(name, _)| name); Ok(quote! 
{ impl #impl_generics crate::infer::InferenceInputSignature for #ident #ty_generics @@ -51,6 +80,12 @@ pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_ { type Signature = #signature; + const PARAM_INFOS: &'static [crate::infer::ParamInfo< + crate::infer::InputScalarKind + >] = &[ + #param_infos + ]; + fn make_run_context( self, sess: &mut R::Session, @@ -80,13 +115,15 @@ pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_ } } -#[proc_macro_derive(TryFromVecOutputTensor)] -pub fn derive_try_from_vec_any_tensor(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - return derive_try_from_vec_any_tensor(&parse_macro_input!(input)) +#[proc_macro_derive(InferenceOutputSignature)] +pub fn derive_inference_output_signature( + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + return derive_inference_output_signature(&parse_macro_input!(input)) .unwrap_or_else(|e| e.to_compile_error()) .into(); - fn derive_try_from_vec_any_tensor( + fn derive_inference_output_signature( input: &DeriveInput, ) -> syn::Result { let DeriveInput { @@ -97,10 +134,41 @@ pub fn derive_try_from_vec_any_tensor(input: proc_macro::TokenStream) -> proc_ma } = input; let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - let field_names = struct_field_names(data)?; - let num_fields = field_names.len(); + + let fields = struct_fields(data)?; + let num_fields = fields.len(); + + let param_infos = fields + .iter() + .map(|(name, ty)| { + let name = name.to_string(); + quote! { + crate::infer::ParamInfo { + name: ::std::borrow::Cow::Borrowed(#name), + dt: < + <#ty as crate::infer::ArrayExt>::Scalar as crate::infer::OutputScalar + >::KIND, + ndim: < + <#ty as crate::infer::ArrayExt>::Dimension as ::ndarray::Dimension + >::NDIM, + }, + } + }) + .collect::(); + + let field_names = fields.iter().map(|(name, _)| name); Ok(quote! 
{ + impl #impl_generics crate::infer::InferenceOutputSignature for #ident #ty_generics + #where_clause + { + const PARAM_INFOS: &'static [crate::infer::ParamInfo< + crate::infer::OutputScalarKind + >] = &[ + #param_infos + ]; + } + impl #impl_generics ::std::convert::TryFrom<::std::vec::Vec> for #ident #ty_generics @@ -133,7 +201,7 @@ pub fn derive_try_from_vec_any_tensor(input: proc_macro::TokenStream) -> proc_ma } } -fn struct_field_names(data: &Data) -> syn::Result> { +fn struct_fields(data: &Data) -> syn::Result> { let fields = match data { Data::Struct(DataStruct { fields: Fields::Named(fields), @@ -153,6 +221,6 @@ fn struct_field_names(data: &Data) -> syn::Result> { Ok(fields .named .iter() - .map(|Field { ident, .. }| ident.as_ref().expect("should be named")) + .map(|Field { ident, ty, .. }| (ident.as_ref().expect("should be named"), ty)) .collect()) } From b7d48f3ab8cc7747f8c72561837ca7b84bcfb872 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Mon, 13 Nov 2023 10:02:51 +0900 Subject: [PATCH 30/47] =?UTF-8?q?signatures=E3=81=AE=E3=83=9E=E3=82=AF?= =?UTF-8?q?=E3=83=AD=E5=8C=96=E3=82=92=E5=AE=8C=E4=BA=86=E3=81=95=E3=81=9B?= =?UTF-8?q?=E3=82=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 1 + Cargo.toml | 1 + crates/voicevox_core/Cargo.toml | 2 +- crates/voicevox_core/src/infer/signatures.rs | 85 +++------ crates/voicevox_core_macros/Cargo.toml | 3 +- crates/voicevox_core_macros/src/lib.rs | 178 ++++++++++++++++--- 6 files changed, 189 insertions(+), 81 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bd738d44a..c7c85b43b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4408,6 +4408,7 @@ dependencies = [ name = "voicevox_core_macros" version = "0.0.0" dependencies = [ + "indexmap 2.0.0", "proc-macro2", "quote", "syn 2.0.38", diff --git a/Cargo.toml b/Cargo.toml index ddcb356e2..65353e6ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ derive_more = "0.99.17" easy-ext = "1.0.1" 
fs-err = { version = "2.9.0", features = ["tokio"] } futures = "0.3.26" +indexmap = { version = "2.0.0", features = ["serde"] } itertools = "0.10.5" ndarray = "0.15.6" once_cell = "1.18.0" diff --git a/crates/voicevox_core/Cargo.toml b/crates/voicevox_core/Cargo.toml index d6ad96761..763b69605 100644 --- a/crates/voicevox_core/Cargo.toml +++ b/crates/voicevox_core/Cargo.toml @@ -21,7 +21,7 @@ educe = "0.4.23" enum-map = "3.0.0-beta.1" fs-err.workspace = true futures.workspace = true -indexmap = { version = "2.0.0", features = ["serde"] } +indexmap.workspace = true itertools.workspace = true nanoid = "0.4.0" ndarray.workspace = true diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/signatures.rs index c4633658d..332b6942c 100644 --- a/crates/voicevox_core/src/infer/signatures.rs +++ b/crates/voicevox_core/src/infer/signatures.rs @@ -1,51 +1,34 @@ -use enum_map::{Enum, EnumMap}; -use macros::{InferenceInputSignature, InferenceOutputSignature}; +use enum_map::Enum; +use macros::{InferenceGroup, InferenceInputSignature, InferenceOutputSignature}; use ndarray::{Array0, Array1, Array2}; -use super::{ - InferenceGroup, InferenceInputSignature as _, InferenceOutputSignature as _, - InferenceSignature, OutputTensor, -}; +use super::{InferenceInputSignature as _, InferenceOutputSignature as _, OutputTensor}; -#[derive(Clone, Copy, Enum)] +#[derive(Clone, Copy, Enum, InferenceGroup)] pub(crate) enum InferenceKind { + #[inference_group( + type Input = PredictDurationInput; + type Output = PredictDurationOutput; + )] PredictDuration, - PredictIntonation, - Decode, -} - -// FIXME: ここもマクロ化する -impl InferenceGroup for InferenceKind { - const INPUT_PARAM_INFOS: enum_map::EnumMap< - Self, - &'static [super::ParamInfo], - > = EnumMap::from_array([ - PredictDurationInput::PARAM_INFOS, - PredictIntonationInput::PARAM_INFOS, - DecodeInput::PARAM_INFOS, - ]); - const OUTPUT_PARAM_INFOS: enum_map::EnumMap< - Self, - &'static [super::ParamInfo], - > 
= EnumMap::from_array([ - PredictDurationOutput::PARAM_INFOS, - PredictIntonationOutput::PARAM_INFOS, - DecodeOutput::PARAM_INFOS, - ]); -} - -pub(crate) enum PredictDuration {} + #[inference_group( + type Input = PredictIntonationInput; + type Output = PredictIntonationOutput; + )] + PredictIntonation, -impl InferenceSignature for PredictDuration { - type Group = InferenceKind; - type Input = PredictDurationInput; - type Output = PredictDurationOutput; - const KIND: InferenceKind = InferenceKind::PredictDuration; + #[inference_group( + type Input = DecodeInput; + type Output = DecodeOutput; + )] + Decode, } #[derive(InferenceInputSignature)] -#[input_signature(Signature = PredictDuration)] +#[inference_input_signature( + type Signature = PredictDuration; +)] pub(crate) struct PredictDurationInput { pub(crate) phoneme_list: Array1, pub(crate) speaker_id: Array1, @@ -56,17 +39,10 @@ pub(crate) struct PredictDurationOutput { pub(crate) phoneme_length: Array1, } -pub(crate) enum PredictIntonation {} - -impl InferenceSignature for PredictIntonation { - type Group = InferenceKind; - type Input = PredictIntonationInput; - type Output = PredictIntonationOutput; - const KIND: InferenceKind = InferenceKind::PredictIntonation; -} - #[derive(InferenceInputSignature)] -#[input_signature(Signature = PredictIntonation)] +#[inference_input_signature( + type Signature = PredictIntonation; +)] pub(crate) struct PredictIntonationInput { pub(crate) length: Array0, pub(crate) vowel_phoneme_list: Array1, @@ -83,17 +59,10 @@ pub(crate) struct PredictIntonationOutput { pub(crate) f0_list: Array1, } -pub(crate) enum Decode {} - -impl InferenceSignature for Decode { - type Group = InferenceKind; - type Input = DecodeInput; - type Output = DecodeOutput; - const KIND: InferenceKind = InferenceKind::Decode; -} - #[derive(InferenceInputSignature)] -#[input_signature(Signature = Decode)] +#[inference_input_signature( + type Signature = Decode; +)] pub(crate) struct DecodeInput { pub(crate) f0: 
Array2, pub(crate) phoneme: Array2, diff --git a/crates/voicevox_core_macros/Cargo.toml b/crates/voicevox_core_macros/Cargo.toml index 242188b3b..957fa3eb8 100644 --- a/crates/voicevox_core_macros/Cargo.toml +++ b/crates/voicevox_core_macros/Cargo.toml @@ -9,6 +9,7 @@ name = "macros" proc-macro = true [dependencies] +indexmap.workspace = true proc-macro2 = "1.0.69" quote = "1.0.33" -syn = "2.0.38" +syn = { version = "2.0.38", features = ["extra-traits"] } diff --git a/crates/voicevox_core_macros/src/lib.rs b/crates/voicevox_core_macros/src/lib.rs index a39aa2f20..07a93c84f 100644 --- a/crates/voicevox_core_macros/src/lib.rs +++ b/crates/voicevox_core_macros/src/lib.rs @@ -1,28 +1,142 @@ #![warn(rust_2018_idioms)] +use indexmap::IndexMap; use quote::quote; use syn::{ parse::{Parse, ParseStream}, parse_macro_input, spanned::Spanned as _, - Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, Token, Type, + Attribute, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, Generics, + ItemType, Type, Variant, }; -#[proc_macro_derive(InferenceGroup)] +#[proc_macro_derive(InferenceGroup, attributes(inference_group))] pub fn derive_inference_group(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let DeriveInput { - ident, generics, .. - } = parse_macro_input!(input as DeriveInput); + return derive_inference_group(&parse_macro_input!(input)) + .unwrap_or_else(|e| e.to_compile_error()) + .into(); + + fn derive_inference_group(input: &DeriveInput) -> syn::Result { + let DeriveInput { + vis, + ident: group_name, + generics, + data, + .. + } = input; + + deny_generics(generics)?; - let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let variants = unit_enum_variants(data)? 
+ .into_iter() + .map(|(attrs, variant_name)| { + let AssocTypes { input, output } = attrs + .iter() + .find(|a| a.path().is_ident("inference_group")) + .ok_or_else(|| { + syn::Error::new( + proc_macro2::Span::call_site(), + "missing `#[inference_group(…)]`", + ) + })? + .parse_args()?; + + Ok((variant_name, (input, output))) + }) + .collect::>>()?; - quote! { - impl #impl_generics crate::infer::InferenceGroup for #ident #ty_generics #where_clause {} + let variant_names = &variants.keys().collect::>(); + + let signatures = variants + .iter() + .map(|(variant_name, (input_ty, output_ty))| { + quote! { + #vis enum #variant_name {} + + impl crate::infer::InferenceSignature for #variant_name { + type Group = #group_name; + type Input = #input_ty; + type Output = #output_ty; + const KIND: Self::Group = #group_name :: #variant_name; + } + } + }); + + Ok(quote! { + impl crate::infer::InferenceGroup for #group_name { + const INPUT_PARAM_INFOS: ::enum_map::EnumMap< + Self, + &'static [crate::infer::ParamInfo], + > = ::enum_map::EnumMap::from_array([ + #(<#variant_names as crate::infer::InferenceSignature>::Input::PARAM_INFOS),* + ]); + + const OUTPUT_PARAM_INFOS: ::enum_map::EnumMap< + Self, + &'static [crate::infer::ParamInfo], + > = ::enum_map::EnumMap::from_array([ + #(<#variant_names as crate::infer::InferenceSignature>::Output::PARAM_INFOS),* + ]); + } + + #(#signatures)* + }) + } + + struct AssocTypes { + input: Type, + output: Type, + } + + impl Parse for AssocTypes { + fn parse(stream: ParseStream<'_>) -> syn::Result { + let mut input = None; + let mut output = None; + + while !stream.is_empty() { + let ItemType { + ident, + generics, + ty, + .. 
+ } = stream.parse()?; + + deny_generics(&generics)?; + + *match &*ident.to_string() { + "Input" => &mut input, + "Output" => &mut output, + _ => { + return Err(syn::Error::new( + ident.span(), + "expected `Input` or `Output`", + )) + } + } = Some(*ty); + } + + let input = + input.ok_or_else(|| syn::Error::new(stream.span(), "missing `type Input = …;`"))?; + + let output = output + .ok_or_else(|| syn::Error::new(stream.span(), "missing `type Output = …;`"))?; + + Ok(Self { input, output }) + } + } + + fn deny_generics(generics: &Generics) -> syn::Result<()> { + if !generics.params.is_empty() { + return Err(syn::Error::new(generics.params.span(), "must be empty")); + } + if let Some(where_clause) = &generics.where_clause { + return Err(syn::Error::new(where_clause.span(), "must be empty")); + } + Ok(()) } - .into() } -#[proc_macro_derive(InferenceInputSignature, attributes(input_signature))] +#[proc_macro_derive(InferenceInputSignature, attributes(inference_input_signature))] pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_macro::TokenStream { return derive_inference_input_signature(&parse_macro_input!(input)) .unwrap_or_else(|e| e.to_compile_error()) @@ -41,11 +155,11 @@ pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_ let AssocTypeSignature(signature) = attrs .iter() - .find(|a| a.path().is_ident("input_signature")) + .find(|a| a.path().is_ident("inference_input_signature")) .ok_or_else(|| { syn::Error::new( proc_macro2::Span::call_site(), - "missing `#[input_signature(…)]`", + "missing `#[inference_input_signature(…)]`", ) })? 
.parse_args()?; @@ -100,17 +214,16 @@ pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_ }) } - struct AssocTypeSignature(syn::Ident); + struct AssocTypeSignature(Type); impl Parse for AssocTypeSignature { fn parse(input: ParseStream<'_>) -> syn::Result { - let key = input.parse::()?; - if key != "Signature" { - return Err(syn::Error::new(key.span(), "expected `Signature`")); + let ItemType { ident, ty, .. } = input.parse()?; + + if ident != "Signature" { + return Err(syn::Error::new(ident.span(), "expected `Signature`")); } - input.parse::()?; - let value = input.parse::()?; - Ok(Self(value)) + Ok(Self(*ty)) } } } @@ -211,10 +324,10 @@ fn struct_fields(data: &Data) -> syn::Result> { return Err(syn::Error::new(fields.span(), "expect named fields")); } Data::Enum(DataEnum { enum_token, .. }) => { - return Err(syn::Error::new(enum_token.span(), "expected an enum")); + return Err(syn::Error::new(enum_token.span(), "expected a struct")); } Data::Union(DataUnion { union_token, .. }) => { - return Err(syn::Error::new(union_token.span(), "expected an enum")); + return Err(syn::Error::new(union_token.span(), "expected a struct")); } }; @@ -224,3 +337,26 @@ fn struct_fields(data: &Data) -> syn::Result> { .map(|Field { ident, ty, .. }| (ident.as_ref().expect("should be named"), ty)) .collect()) } + +fn unit_enum_variants(data: &Data) -> syn::Result> { + let variants = match data { + Data::Struct(DataStruct { struct_token, .. }) => { + return Err(syn::Error::new(struct_token.span(), "expected an enum")); + } + Data::Enum(DataEnum { variants, .. }) => variants, + Data::Union(DataUnion { union_token, .. }) => { + return Err(syn::Error::new(union_token.span(), "expected an enum")); + } + }; + + for Variant { fields, .. } in variants { + if *fields != Fields::Unit { + return Err(syn::Error::new(fields.span(), "must be unit")); + } + } + + Ok(variants + .iter() + .map(|Variant { attrs, ident, .. 
}| (&**attrs, ident)) + .collect()) +} From b6db1c0bd5f7df402d04ddb4ec1313755bc032e4 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Tue, 14 Nov 2023 00:27:12 +0900 Subject: [PATCH 31/47] Minor refactor --- crates/voicevox_core/src/infer.rs | 16 +--------------- crates/voicevox_core/src/infer/status.rs | 21 ++++++++++++++++----- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index a8577bd24..b53c52488 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -3,10 +3,7 @@ pub(crate) mod runtimes; pub(crate) mod signatures; pub(crate) mod status; -use std::{ - borrow::Cow, - fmt::{self, Debug, Display}, -}; +use std::{borrow::Cow, fmt::Debug}; use derive_new::new; use enum_map::{Enum, EnumMap}; @@ -132,17 +129,6 @@ impl ParamInfo { } } -impl Display for ParamInfo { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}: {}", self.name, self.dt)?; - if let Some(ndim) = self.ndim { - f.write_str(&"[]".repeat(ndim)) - } else { - f.write_str("[]...") - } - } -} - pub(crate) trait ArrayExt { type Scalar; type Dimension: Dimension; diff --git a/crates/voicevox_core/src/infer/status.rs b/crates/voicevox_core/src/infer/status.rs index e1d2a8e3a..c6a09ed3d 100644 --- a/crates/voicevox_core/src/infer/status.rs +++ b/crates/voicevox_core/src/infer/status.rs @@ -277,14 +277,25 @@ impl SessionSet { && itertools::zip_eq(expected, actual) .all(|(expected, actual)| expected.accepts(actual))) { - bail!( - "expected {{{}}}, got {{{}}}", - expected.iter().join(", "), - actual.iter().join(", "), - ) + let expected = display_param_infos(expected); + let actual = display_param_infos(actual); + bail!("expected {{{expected}}}, got {{{actual}}}") } Ok(()) } + + fn display_param_infos(infos: &[ParamInfo]) -> impl Display { + infos + .iter() + .map(|ParamInfo { name, dt, ndim }| { + let brackets = match *ndim { + Some(ndim) => "[]".repeat(ndim), + 
None => "[]...".to_owned(), + }; + format!("{name}: {dt}{brackets}") + }) + .join(", ") + } } } From 96a93e9f16c8b0b9e3ea01083bfe42790993dcd4 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Tue, 14 Nov 2023 23:05:32 +0900 Subject: [PATCH 32/47] =?UTF-8?q?`InferenceGroup`=20=E2=86=92=20`Inference?= =?UTF-8?q?Domain`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer.rs | 8 +-- .../src/infer/{signatures.rs => domain.rs} | 10 ++-- crates/voicevox_core/src/infer/status.rs | 56 +++++++++---------- crates/voicevox_core/src/inference_core.rs | 2 +- crates/voicevox_core/src/voice_model.rs | 2 +- crates/voicevox_core_macros/src/lib.rs | 20 +++---- 6 files changed, 49 insertions(+), 49 deletions(-) rename crates/voicevox_core/src/infer/{signatures.rs => domain.rs} (90%) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index b53c52488..aa30b3cc9 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -1,6 +1,6 @@ +pub(crate) mod domain; mod model_file; pub(crate) mod runtimes; -pub(crate) mod signatures; pub(crate) mod status; use std::{borrow::Cow, fmt::Debug}; @@ -36,16 +36,16 @@ pub(crate) trait InferenceRuntime: 'static { fn run(ctx: Self::RunContext<'_>) -> anyhow::Result>; } -pub(crate) trait InferenceGroup: Copy + Enum { +pub(crate) trait InferenceDomain: Copy + Enum { const INPUT_PARAM_INFOS: EnumMap]>; const OUTPUT_PARAM_INFOS: EnumMap]>; } pub(crate) trait InferenceSignature: Sized + Send + 'static { - type Group: InferenceGroup; + type Domain: InferenceDomain; type Input: InferenceInputSignature; type Output: InferenceOutputSignature; - const KIND: Self::Group; + const KIND: Self::Domain; } pub(crate) trait InferenceInputSignature: Send + 'static { diff --git a/crates/voicevox_core/src/infer/signatures.rs b/crates/voicevox_core/src/infer/domain.rs similarity index 90% rename from 
crates/voicevox_core/src/infer/signatures.rs rename to crates/voicevox_core/src/infer/domain.rs index 332b6942c..1b18fdcbd 100644 --- a/crates/voicevox_core/src/infer/signatures.rs +++ b/crates/voicevox_core/src/infer/domain.rs @@ -1,24 +1,24 @@ use enum_map::Enum; -use macros::{InferenceGroup, InferenceInputSignature, InferenceOutputSignature}; +use macros::{InferenceDomain, InferenceInputSignature, InferenceOutputSignature}; use ndarray::{Array0, Array1, Array2}; use super::{InferenceInputSignature as _, InferenceOutputSignature as _, OutputTensor}; -#[derive(Clone, Copy, Enum, InferenceGroup)] +#[derive(Clone, Copy, Enum, InferenceDomain)] pub(crate) enum InferenceKind { - #[inference_group( + #[inference_domain( type Input = PredictDurationInput; type Output = PredictDurationOutput; )] PredictDuration, - #[inference_group( + #[inference_domain( type Input = PredictIntonationInput; type Output = PredictIntonationOutput; )] PredictIntonation, - #[inference_group( + #[inference_domain( type Input = DecodeInput; type Output = DecodeOutput; )] diff --git a/crates/voicevox_core/src/infer/status.rs b/crates/voicevox_core/src/infer/status.rs index c6a09ed3d..f136ecb8b 100644 --- a/crates/voicevox_core/src/infer/status.rs +++ b/crates/voicevox_core/src/infer/status.rs @@ -20,17 +20,17 @@ use crate::{ }; use super::{ - model_file, InferenceGroup, InferenceInputSignature, InferenceRuntime, InferenceSessionOptions, - InferenceSignature, + model_file, InferenceDomain, InferenceInputSignature, InferenceRuntime, + InferenceSessionOptions, InferenceSignature, }; -pub(crate) struct Status { - loaded_models: std::sync::Mutex>, - session_options: EnumMap, +pub(crate) struct Status { + loaded_models: std::sync::Mutex>, + session_options: EnumMap, } -impl Status { - pub fn new(session_options: EnumMap) -> Self { +impl Status { + pub fn new(session_options: EnumMap) -> Self { Self { loaded_models: Default::default(), session_options, @@ -40,7 +40,7 @@ impl Status { pub async fn 
load_model( &self, model: &VoiceModel, - model_bytes: &EnumMap>, + model_bytes: &EnumMap>, ) -> Result<()> { self.loaded_models .lock() @@ -100,7 +100,7 @@ impl Status { ) -> Result<::Output> where I: InferenceInputSignature, - I::Signature: InferenceSignature, + I::Signature: InferenceSignature, { let sess = self.loaded_models.lock().unwrap().get(model_id); @@ -114,18 +114,18 @@ impl Status { /// /// この構造体のメソッドは、すべて一瞬で完了すべきである。 #[derive(Educe)] -#[educe(Default(bound = "R: InferenceRuntime, G: InferenceGroup"))] -struct LoadedModels( - BTreeMap>, +#[educe(Default(bound = "R: InferenceRuntime, D: InferenceDomain"))] +struct LoadedModels( + BTreeMap>, ); -struct LoadedModel { +struct LoadedModel { model_inner_ids: BTreeMap, metas: VoiceModelMeta, - session_set: SessionSet, + session_set: SessionSet, } -impl LoadedModels { +impl LoadedModels { fn metas(&self) -> VoiceModelMeta { self.0 .values() @@ -164,7 +164,7 @@ impl LoadedModels { fn get(&self, model_id: &VoiceModelId) -> SessionCell where I: InferenceInputSignature, - I::Signature: InferenceSignature, + I::Signature: InferenceSignature, { self.0[model_id].session_set.get() } @@ -207,7 +207,7 @@ impl LoadedModels { Ok(()) } - fn insert(&mut self, model: &VoiceModel, session_set: SessionSet) -> Result<()> { + fn insert(&mut self, model: &VoiceModel, session_set: SessionSet) -> Result<()> { self.ensure_acceptable(model)?; let prev = self.0.insert( @@ -240,20 +240,20 @@ impl LoadedModels { } } -struct SessionSet( - EnumMap>>, +struct SessionSet( + EnumMap>>, ); -impl SessionSet { +impl SessionSet { fn new( - model_bytes: &EnumMap>, - options: &EnumMap, + model_bytes: &EnumMap>, + options: &EnumMap, ) -> anyhow::Result { let mut sessions = model_bytes .iter() .map(|(k, m)| { - let expected_input_param_infos = G::INPUT_PARAM_INFOS[k]; - let expected_output_param_infos = G::OUTPUT_PARAM_INFOS[k]; + let expected_input_param_infos = D::INPUT_PARAM_INFOS[k]; + let expected_output_param_infos = D::OUTPUT_PARAM_INFOS[k]; 
let (sess, actual_input_param_infos, actual_output_param_infos) = R::new_session(|| model_file::decrypt(m), options[k])?; @@ -265,7 +265,7 @@ impl SessionSet { }) .collect::>>()?; - return Ok(Self(EnumMap::::from_fn(|k| { + return Ok(Self(EnumMap::::from_fn(|k| { sessions.remove(&k.into_usize()).expect("should exist") }))); @@ -299,11 +299,11 @@ impl SessionSet { } } -impl SessionSet { +impl SessionSet { fn get(&self) -> SessionCell where I: InferenceInputSignature, - I::Signature: InferenceSignature, + I::Signature: InferenceSignature, { SessionCell { inner: self.0[I::Signature::KIND].clone(), @@ -334,7 +334,7 @@ mod tests { use rstest::rstest; use crate::{ - infer::signatures::InferenceKind, macros::tests::assert_debug_fmt_eq, + infer::domain::InferenceKind, macros::tests::assert_debug_fmt_eq, synthesizer::InferenceRuntimeImpl, test_util::open_default_vvm_file, }; diff --git a/crates/voicevox_core/src/inference_core.rs b/crates/voicevox_core/src/inference_core.rs index 6de29c201..71628ef96 100644 --- a/crates/voicevox_core/src/inference_core.rs +++ b/crates/voicevox_core/src/inference_core.rs @@ -1,7 +1,7 @@ use enum_map::enum_map; use crate::infer::{ - signatures::{ + domain::{ DecodeInput, DecodeOutput, InferenceKind, PredictDurationInput, PredictDurationOutput, PredictIntonationInput, PredictIntonationOutput, }, diff --git a/crates/voicevox_core/src/voice_model.rs b/crates/voicevox_core/src/voice_model.rs index 5b75bcacf..96494eab5 100644 --- a/crates/voicevox_core/src/voice_model.rs +++ b/crates/voicevox_core/src/voice_model.rs @@ -4,7 +4,7 @@ use futures::future::join3; use serde::{de::DeserializeOwned, Deserialize}; use super::*; -use crate::infer::signatures::InferenceKind; +use crate::infer::domain::InferenceKind; use std::{ collections::{BTreeMap, HashMap}, io, diff --git a/crates/voicevox_core_macros/src/lib.rs b/crates/voicevox_core_macros/src/lib.rs index 07a93c84f..4fc3da39b 100644 --- a/crates/voicevox_core_macros/src/lib.rs +++ 
b/crates/voicevox_core_macros/src/lib.rs @@ -10,16 +10,16 @@ use syn::{ ItemType, Type, Variant, }; -#[proc_macro_derive(InferenceGroup, attributes(inference_group))] -pub fn derive_inference_group(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - return derive_inference_group(&parse_macro_input!(input)) +#[proc_macro_derive(InferenceDomain, attributes(inference_domain))] +pub fn derive_inference_domain(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + return derive_inference_domain(&parse_macro_input!(input)) .unwrap_or_else(|e| e.to_compile_error()) .into(); - fn derive_inference_group(input: &DeriveInput) -> syn::Result { + fn derive_inference_domain(input: &DeriveInput) -> syn::Result { let DeriveInput { vis, - ident: group_name, + ident: domain_name, generics, data, .. @@ -32,11 +32,11 @@ pub fn derive_inference_group(input: proc_macro::TokenStream) -> proc_macro::Tok .map(|(attrs, variant_name)| { let AssocTypes { input, output } = attrs .iter() - .find(|a| a.path().is_ident("inference_group")) + .find(|a| a.path().is_ident("inference_domain")) .ok_or_else(|| { syn::Error::new( proc_macro2::Span::call_site(), - "missing `#[inference_group(…)]`", + "missing `#[inference_domain(…)]`", ) })? .parse_args()?; @@ -54,16 +54,16 @@ pub fn derive_inference_group(input: proc_macro::TokenStream) -> proc_macro::Tok #vis enum #variant_name {} impl crate::infer::InferenceSignature for #variant_name { - type Group = #group_name; + type Domain = #domain_name; type Input = #input_ty; type Output = #output_ty; - const KIND: Self::Group = #group_name :: #variant_name; + const KIND: Self::Domain = #domain_name :: #variant_name; } } }); Ok(quote! 
{ - impl crate::infer::InferenceGroup for #group_name { + impl crate::infer::InferenceDomain for #domain_name { const INPUT_PARAM_INFOS: ::enum_map::EnumMap< Self, &'static [crate::infer::ParamInfo], From 59d87797ca64e39375c48cb4413a6be3f8d1cbfa Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Tue, 14 Nov 2023 23:19:37 +0900 Subject: [PATCH 33/47] Minor refactor --- .../src/inference_domain.rs | 337 +++++++++++++++++ crates/voicevox_core_macros/src/lib.rs | 355 +----------------- 2 files changed, 348 insertions(+), 344 deletions(-) create mode 100644 crates/voicevox_core_macros/src/inference_domain.rs diff --git a/crates/voicevox_core_macros/src/inference_domain.rs b/crates/voicevox_core_macros/src/inference_domain.rs new file mode 100644 index 000000000..73b5ac1cc --- /dev/null +++ b/crates/voicevox_core_macros/src/inference_domain.rs @@ -0,0 +1,337 @@ +use indexmap::IndexMap; +use quote::quote; +use syn::{ + parse::{Parse, ParseStream}, + spanned::Spanned as _, + Attribute, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, Generics, + ItemType, Type, Variant, +}; + +pub(crate) fn derive_inference_domain( + input: &DeriveInput, +) -> syn::Result { + let DeriveInput { + vis, + ident: domain_name, + generics, + data, + .. + } = input; + + deny_generics(generics)?; + + let variants = unit_enum_variants(data)? + .into_iter() + .map(|(attrs, variant_name)| { + let AssocTypes { input, output } = attrs + .iter() + .find(|a| a.path().is_ident("inference_domain")) + .ok_or_else(|| { + syn::Error::new( + proc_macro2::Span::call_site(), + "missing `#[inference_domain(…)]`", + ) + })? + .parse_args()?; + + Ok((variant_name, (input, output))) + }) + .collect::>>()?; + + let variant_names = &variants.keys().collect::>(); + + let signatures = variants + .iter() + .map(|(variant_name, (input_ty, output_ty))| { + quote! 
{ + #vis enum #variant_name {} + + impl crate::infer::InferenceSignature for #variant_name { + type Domain = #domain_name; + type Input = #input_ty; + type Output = #output_ty; + const KIND: Self::Domain = #domain_name :: #variant_name; + } + } + }); + + return Ok(quote! { + impl crate::infer::InferenceDomain for #domain_name { + const INPUT_PARAM_INFOS: ::enum_map::EnumMap< + Self, + &'static [crate::infer::ParamInfo], + > = ::enum_map::EnumMap::from_array([ + #(<#variant_names as crate::infer::InferenceSignature>::Input::PARAM_INFOS),* + ]); + + const OUTPUT_PARAM_INFOS: ::enum_map::EnumMap< + Self, + &'static [crate::infer::ParamInfo], + > = ::enum_map::EnumMap::from_array([ + #(<#variant_names as crate::infer::InferenceSignature>::Output::PARAM_INFOS),* + ]); + } + + #(#signatures)* + }); + + struct AssocTypes { + input: Type, + output: Type, + } + + impl Parse for AssocTypes { + fn parse(stream: ParseStream<'_>) -> syn::Result { + let mut input = None; + let mut output = None; + + while !stream.is_empty() { + let ItemType { + ident, + generics, + ty, + .. 
+ } = stream.parse()?; + + deny_generics(&generics)?; + + *match &*ident.to_string() { + "Input" => &mut input, + "Output" => &mut output, + _ => { + return Err(syn::Error::new( + ident.span(), + "expected `Input` or `Output`", + )) + } + } = Some(*ty); + } + + let input = + input.ok_or_else(|| syn::Error::new(stream.span(), "missing `type Input = …;`"))?; + + let output = output + .ok_or_else(|| syn::Error::new(stream.span(), "missing `type Output = …;`"))?; + + Ok(Self { input, output }) + } + } + + fn deny_generics(generics: &Generics) -> syn::Result<()> { + if !generics.params.is_empty() { + return Err(syn::Error::new(generics.params.span(), "must be empty")); + } + if let Some(where_clause) = &generics.where_clause { + return Err(syn::Error::new(where_clause.span(), "must be empty")); + } + Ok(()) + } +} + +pub(crate) fn derive_inference_input_signature( + input: &DeriveInput, +) -> syn::Result { + let DeriveInput { + attrs, + ident, + generics, + data, + .. + } = input; + + let AssocTypeSignature(signature) = attrs + .iter() + .find(|a| a.path().is_ident("inference_input_signature")) + .ok_or_else(|| { + syn::Error::new( + proc_macro2::Span::call_site(), + "missing `#[inference_input_signature(…)]`", + ) + })? + .parse_args()?; + + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let fields = struct_fields(data)?; + + let param_infos = fields + .iter() + .map(|(name, ty)| { + let name = name.to_string(); + quote! { + crate::infer::ParamInfo { + name: ::std::borrow::Cow::Borrowed(#name), + dt: < + <#ty as crate::infer::ArrayExt>::Scalar as crate::infer::InputScalar + >::KIND, + ndim: < + <#ty as crate::infer::ArrayExt>::Dimension as ::ndarray::Dimension + >::NDIM, + }, + } + }) + .collect::(); + + let field_names = fields.iter().map(|(name, _)| name); + + return Ok(quote! 
{ + impl #impl_generics crate::infer::InferenceInputSignature for #ident #ty_generics + #where_clause + { + type Signature = #signature; + + const PARAM_INFOS: &'static [crate::infer::ParamInfo< + crate::infer::InputScalarKind + >] = &[ + #param_infos + ]; + + fn make_run_context( + self, + sess: &mut R::Session, + ) -> R::RunContext<'_> { + let mut ctx = as ::std::convert::From<_>>::from(sess); + #( + R::push_input(self.#field_names, &mut ctx); + )* + ctx + } + } + }); + + struct AssocTypeSignature(Type); + + impl Parse for AssocTypeSignature { + fn parse(input: ParseStream<'_>) -> syn::Result { + let ItemType { ident, ty, .. } = input.parse()?; + + if ident != "Signature" { + return Err(syn::Error::new(ident.span(), "expected `Signature`")); + } + Ok(Self(*ty)) + } + } +} + +pub(crate) fn derive_inference_output_signature( + input: &DeriveInput, +) -> syn::Result { + let DeriveInput { + ident, + generics, + data, + .. + } = input; + + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let fields = struct_fields(data)?; + let num_fields = fields.len(); + + let param_infos = fields + .iter() + .map(|(name, ty)| { + let name = name.to_string(); + quote! { + crate::infer::ParamInfo { + name: ::std::borrow::Cow::Borrowed(#name), + dt: < + <#ty as crate::infer::ArrayExt>::Scalar as crate::infer::OutputScalar + >::KIND, + ndim: < + <#ty as crate::infer::ArrayExt>::Dimension as ::ndarray::Dimension + >::NDIM, + }, + } + }) + .collect::(); + + let field_names = fields.iter().map(|(name, _)| name); + + Ok(quote! 
{ + impl #impl_generics crate::infer::InferenceOutputSignature for #ident #ty_generics + #where_clause + { + const PARAM_INFOS: &'static [crate::infer::ParamInfo< + crate::infer::OutputScalarKind + >] = &[ + #param_infos + ]; + } + + impl #impl_generics ::std::convert::TryFrom<::std::vec::Vec> + for #ident #ty_generics + #where_clause + { + type Error = ::anyhow::Error; + + fn try_from( + tensors: ::std::vec::Vec, + ) -> ::std::result::Result { + ::anyhow::ensure!( + tensors.len() == #num_fields, + "expected {} tensor(s), got {}", + #num_fields, + tensors.len(), + ); + + let tensors = &mut ::std::iter::IntoIterator::into_iter(tensors); + ::std::result::Result::Ok(Self { + #( + #field_names: ::std::convert::TryInto::try_into( + ::std::iter::Iterator::next(tensors) + .expect("the length should have been checked"), + )?, + )* + }) + } + } + }) +} + +fn struct_fields(data: &Data) -> syn::Result> { + let fields = match data { + Data::Struct(DataStruct { + fields: Fields::Named(fields), + .. + }) => fields, + Data::Struct(DataStruct { fields, .. }) => { + return Err(syn::Error::new(fields.span(), "expect named fields")); + } + Data::Enum(DataEnum { enum_token, .. }) => { + return Err(syn::Error::new(enum_token.span(), "expected a struct")); + } + Data::Union(DataUnion { union_token, .. }) => { + return Err(syn::Error::new(union_token.span(), "expected a struct")); + } + }; + + Ok(fields + .named + .iter() + .map(|Field { ident, ty, .. }| (ident.as_ref().expect("should be named"), ty)) + .collect()) +} + +fn unit_enum_variants(data: &Data) -> syn::Result> { + let variants = match data { + Data::Struct(DataStruct { struct_token, .. }) => { + return Err(syn::Error::new(struct_token.span(), "expected an enum")); + } + Data::Enum(DataEnum { variants, .. }) => variants, + Data::Union(DataUnion { union_token, .. }) => { + return Err(syn::Error::new(union_token.span(), "expected an enum")); + } + }; + + for Variant { fields, .. 
} in variants { + if *fields != Fields::Unit { + return Err(syn::Error::new(fields.span(), "must be unit")); + } + } + + Ok(variants + .iter() + .map(|Variant { attrs, ident, .. }| (&**attrs, ident)) + .collect()) +} diff --git a/crates/voicevox_core_macros/src/lib.rs b/crates/voicevox_core_macros/src/lib.rs index 4fc3da39b..fee3467c7 100644 --- a/crates/voicevox_core_macros/src/lib.rs +++ b/crates/voicevox_core_macros/src/lib.rs @@ -1,362 +1,29 @@ #![warn(rust_2018_idioms)] -use indexmap::IndexMap; -use quote::quote; -use syn::{ - parse::{Parse, ParseStream}, - parse_macro_input, - spanned::Spanned as _, - Attribute, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, Generics, - ItemType, Type, Variant, -}; +mod inference_domain; + +use syn::parse_macro_input; #[proc_macro_derive(InferenceDomain, attributes(inference_domain))] pub fn derive_inference_domain(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - return derive_inference_domain(&parse_macro_input!(input)) - .unwrap_or_else(|e| e.to_compile_error()) - .into(); - - fn derive_inference_domain(input: &DeriveInput) -> syn::Result { - let DeriveInput { - vis, - ident: domain_name, - generics, - data, - .. - } = input; - - deny_generics(generics)?; - - let variants = unit_enum_variants(data)? - .into_iter() - .map(|(attrs, variant_name)| { - let AssocTypes { input, output } = attrs - .iter() - .find(|a| a.path().is_ident("inference_domain")) - .ok_or_else(|| { - syn::Error::new( - proc_macro2::Span::call_site(), - "missing `#[inference_domain(…)]`", - ) - })? - .parse_args()?; - - Ok((variant_name, (input, output))) - }) - .collect::>>()?; - - let variant_names = &variants.keys().collect::>(); - - let signatures = variants - .iter() - .map(|(variant_name, (input_ty, output_ty))| { - quote! 
{ - #vis enum #variant_name {} - - impl crate::infer::InferenceSignature for #variant_name { - type Domain = #domain_name; - type Input = #input_ty; - type Output = #output_ty; - const KIND: Self::Domain = #domain_name :: #variant_name; - } - } - }); - - Ok(quote! { - impl crate::infer::InferenceDomain for #domain_name { - const INPUT_PARAM_INFOS: ::enum_map::EnumMap< - Self, - &'static [crate::infer::ParamInfo], - > = ::enum_map::EnumMap::from_array([ - #(<#variant_names as crate::infer::InferenceSignature>::Input::PARAM_INFOS),* - ]); - - const OUTPUT_PARAM_INFOS: ::enum_map::EnumMap< - Self, - &'static [crate::infer::ParamInfo], - > = ::enum_map::EnumMap::from_array([ - #(<#variant_names as crate::infer::InferenceSignature>::Output::PARAM_INFOS),* - ]); - } - - #(#signatures)* - }) - } - - struct AssocTypes { - input: Type, - output: Type, - } - - impl Parse for AssocTypes { - fn parse(stream: ParseStream<'_>) -> syn::Result { - let mut input = None; - let mut output = None; - - while !stream.is_empty() { - let ItemType { - ident, - generics, - ty, - .. 
- } = stream.parse()?; - - deny_generics(&generics)?; - - *match &*ident.to_string() { - "Input" => &mut input, - "Output" => &mut output, - _ => { - return Err(syn::Error::new( - ident.span(), - "expected `Input` or `Output`", - )) - } - } = Some(*ty); - } - - let input = - input.ok_or_else(|| syn::Error::new(stream.span(), "missing `type Input = …;`"))?; - - let output = output - .ok_or_else(|| syn::Error::new(stream.span(), "missing `type Output = …;`"))?; - - Ok(Self { input, output }) - } - } - - fn deny_generics(generics: &Generics) -> syn::Result<()> { - if !generics.params.is_empty() { - return Err(syn::Error::new(generics.params.span(), "must be empty")); - } - if let Some(where_clause) = &generics.where_clause { - return Err(syn::Error::new(where_clause.span(), "must be empty")); - } - Ok(()) - } + let input = &parse_macro_input!(input); + from_syn(inference_domain::derive_inference_domain(input)) } #[proc_macro_derive(InferenceInputSignature, attributes(inference_input_signature))] pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - return derive_inference_input_signature(&parse_macro_input!(input)) - .unwrap_or_else(|e| e.to_compile_error()) - .into(); - - fn derive_inference_input_signature( - input: &DeriveInput, - ) -> syn::Result { - let DeriveInput { - attrs, - ident, - generics, - data, - .. - } = input; - - let AssocTypeSignature(signature) = attrs - .iter() - .find(|a| a.path().is_ident("inference_input_signature")) - .ok_or_else(|| { - syn::Error::new( - proc_macro2::Span::call_site(), - "missing `#[inference_input_signature(…)]`", - ) - })? - .parse_args()?; - - let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - - let fields = struct_fields(data)?; - - let param_infos = fields - .iter() - .map(|(name, ty)| { - let name = name.to_string(); - quote! 
{ - crate::infer::ParamInfo { - name: ::std::borrow::Cow::Borrowed(#name), - dt: < - <#ty as crate::infer::ArrayExt>::Scalar as crate::infer::InputScalar - >::KIND, - ndim: < - <#ty as crate::infer::ArrayExt>::Dimension as ::ndarray::Dimension - >::NDIM, - }, - } - }) - .collect::(); - - let field_names = fields.iter().map(|(name, _)| name); - - Ok(quote! { - impl #impl_generics crate::infer::InferenceInputSignature for #ident #ty_generics - #where_clause - { - type Signature = #signature; - - const PARAM_INFOS: &'static [crate::infer::ParamInfo< - crate::infer::InputScalarKind - >] = &[ - #param_infos - ]; - - fn make_run_context( - self, - sess: &mut R::Session, - ) -> R::RunContext<'_> { - let mut ctx = as ::std::convert::From<_>>::from(sess); - #( - R::push_input(self.#field_names, &mut ctx); - )* - ctx - } - } - }) - } - - struct AssocTypeSignature(Type); - - impl Parse for AssocTypeSignature { - fn parse(input: ParseStream<'_>) -> syn::Result { - let ItemType { ident, ty, .. } = input.parse()?; - - if ident != "Signature" { - return Err(syn::Error::new(ident.span(), "expected `Signature`")); - } - Ok(Self(*ty)) - } - } + let input = &parse_macro_input!(input); + from_syn(inference_domain::derive_inference_input_signature(input)) } #[proc_macro_derive(InferenceOutputSignature)] pub fn derive_inference_output_signature( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream { - return derive_inference_output_signature(&parse_macro_input!(input)) - .unwrap_or_else(|e| e.to_compile_error()) - .into(); - - fn derive_inference_output_signature( - input: &DeriveInput, - ) -> syn::Result { - let DeriveInput { - ident, - generics, - data, - .. - } = input; - - let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - - let fields = struct_fields(data)?; - let num_fields = fields.len(); - - let param_infos = fields - .iter() - .map(|(name, ty)| { - let name = name.to_string(); - quote! 
{ - crate::infer::ParamInfo { - name: ::std::borrow::Cow::Borrowed(#name), - dt: < - <#ty as crate::infer::ArrayExt>::Scalar as crate::infer::OutputScalar - >::KIND, - ndim: < - <#ty as crate::infer::ArrayExt>::Dimension as ::ndarray::Dimension - >::NDIM, - }, - } - }) - .collect::(); - - let field_names = fields.iter().map(|(name, _)| name); - - Ok(quote! { - impl #impl_generics crate::infer::InferenceOutputSignature for #ident #ty_generics - #where_clause - { - const PARAM_INFOS: &'static [crate::infer::ParamInfo< - crate::infer::OutputScalarKind - >] = &[ - #param_infos - ]; - } - - impl #impl_generics - ::std::convert::TryFrom<::std::vec::Vec> - for #ident #ty_generics - #where_clause - { - type Error = ::anyhow::Error; - - fn try_from( - tensors: ::std::vec::Vec, - ) -> ::std::result::Result { - ::anyhow::ensure!( - tensors.len() == #num_fields, - "expected {} tensor(s), got {}", - #num_fields, - tensors.len(), - ); - - let tensors = &mut ::std::iter::IntoIterator::into_iter(tensors); - ::std::result::Result::Ok(Self { - #( - #field_names: ::std::convert::TryInto::try_into( - ::std::iter::Iterator::next(tensors) - .expect("the length should have been checked"), - )?, - )* - }) - } - } - }) - } + let input = &parse_macro_input!(input); + from_syn(inference_domain::derive_inference_output_signature(input)) } -fn struct_fields(data: &Data) -> syn::Result> { - let fields = match data { - Data::Struct(DataStruct { - fields: Fields::Named(fields), - .. - }) => fields, - Data::Struct(DataStruct { fields, .. }) => { - return Err(syn::Error::new(fields.span(), "expect named fields")); - } - Data::Enum(DataEnum { enum_token, .. }) => { - return Err(syn::Error::new(enum_token.span(), "expected a struct")); - } - Data::Union(DataUnion { union_token, .. }) => { - return Err(syn::Error::new(union_token.span(), "expected a struct")); - } - }; - - Ok(fields - .named - .iter() - .map(|Field { ident, ty, .. 
}| (ident.as_ref().expect("should be named"), ty)) - .collect()) -} - -fn unit_enum_variants(data: &Data) -> syn::Result> { - let variants = match data { - Data::Struct(DataStruct { struct_token, .. }) => { - return Err(syn::Error::new(struct_token.span(), "expected an enum")); - } - Data::Enum(DataEnum { variants, .. }) => variants, - Data::Union(DataUnion { union_token, .. }) => { - return Err(syn::Error::new(union_token.span(), "expected an enum")); - } - }; - - for Variant { fields, .. } in variants { - if *fields != Fields::Unit { - return Err(syn::Error::new(fields.span(), "must be unit")); - } - } - - Ok(variants - .iter() - .map(|Variant { attrs, ident, .. }| (&**attrs, ident)) - .collect()) +fn from_syn(result: syn::Result) -> proc_macro::TokenStream { + result.unwrap_or_else(|e| e.to_compile_error()).into() } From d0dc56ff96cec6652540016343aa9c47ffc1c819 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Tue, 14 Nov 2023 23:31:18 +0900 Subject: [PATCH 34/47] =?UTF-8?q?`InferenceDomain::{INPUT,OUTPUT}=5FPARAM?= =?UTF-8?q?=5FINFOS`=E3=82=92=E7=B5=B1=E5=90=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer.rs | 10 ++++++++-- crates/voicevox_core/src/infer/status.rs | 3 +-- .../src/inference_domain.rs | 19 +++++++++---------- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index aa30b3cc9..6b94ade80 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -37,8 +37,14 @@ pub(crate) trait InferenceRuntime: 'static { } pub(crate) trait InferenceDomain: Copy + Enum { - const INPUT_PARAM_INFOS: EnumMap]>; - const OUTPUT_PARAM_INFOS: EnumMap]>; + #[allow(clippy::type_complexity)] + const PARAM_INFOS: EnumMap< + Self, + ( + &'static [ParamInfo], + &'static [ParamInfo], + ), + >; } pub(crate) trait InferenceSignature: Sized + Send + 'static { diff --git 
a/crates/voicevox_core/src/infer/status.rs b/crates/voicevox_core/src/infer/status.rs index f136ecb8b..5740c719a 100644 --- a/crates/voicevox_core/src/infer/status.rs +++ b/crates/voicevox_core/src/infer/status.rs @@ -252,8 +252,7 @@ impl SessionSet { let mut sessions = model_bytes .iter() .map(|(k, m)| { - let expected_input_param_infos = D::INPUT_PARAM_INFOS[k]; - let expected_output_param_infos = D::OUTPUT_PARAM_INFOS[k]; + let (expected_input_param_infos, expected_output_param_infos) = D::PARAM_INFOS[k]; let (sess, actual_input_param_infos, actual_output_param_infos) = R::new_session(|| model_file::decrypt(m), options[k])?; diff --git a/crates/voicevox_core_macros/src/inference_domain.rs b/crates/voicevox_core_macros/src/inference_domain.rs index 73b5ac1cc..69bc809a5 100644 --- a/crates/voicevox_core_macros/src/inference_domain.rs +++ b/crates/voicevox_core_macros/src/inference_domain.rs @@ -57,18 +57,17 @@ pub(crate) fn derive_inference_domain( return Ok(quote! { impl crate::infer::InferenceDomain for #domain_name { - const INPUT_PARAM_INFOS: ::enum_map::EnumMap< + const PARAM_INFOS: ::enum_map::EnumMap< Self, - &'static [crate::infer::ParamInfo], + ( + &'static [crate::infer::ParamInfo], + &'static [crate::infer::ParamInfo], + ), > = ::enum_map::EnumMap::from_array([ - #(<#variant_names as crate::infer::InferenceSignature>::Input::PARAM_INFOS),* - ]); - - const OUTPUT_PARAM_INFOS: ::enum_map::EnumMap< - Self, - &'static [crate::infer::ParamInfo], - > = ::enum_map::EnumMap::from_array([ - #(<#variant_names as crate::infer::InferenceSignature>::Output::PARAM_INFOS),* + #(( + <#variant_names as crate::infer::InferenceSignature>::Input::PARAM_INFOS, + <#variant_names as crate::infer::InferenceSignature>::Output::PARAM_INFOS + )),* ]); } From c654cd1f2d10acbc3f68aa4e6694525d4566beb2 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Tue, 14 Nov 2023 23:35:06 +0900 Subject: [PATCH 35/47] =?UTF-8?q?`InferenceDomain::PARAM=5FINFOS`=E3=81=AB?= =?UTF-8?q?docstring?= 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 6b94ade80..3fc6db2cd 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -37,6 +37,9 @@ pub(crate) trait InferenceRuntime: 'static { } pub(crate) trait InferenceDomain: Copy + Enum { + /// `{InferenceInputSignature,InferenceOutputSignature}::PARAM_INFOS`を集めたもの。 + /// + /// マクロ(voicevox_core_macros)で実装される前提。 #[allow(clippy::type_complexity)] const PARAM_INFOS: EnumMap< Self, From 868d3f61e76cf04147674c86d60142c7fec0ca6b Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Tue, 14 Nov 2023 23:46:06 +0900 Subject: [PATCH 36/47] =?UTF-8?q?voicevox=5Fcore=5Fmacros=E3=81=ABdocstrin?= =?UTF-8?q?g?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core_macros/src/lib.rs | 60 ++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/crates/voicevox_core_macros/src/lib.rs b/crates/voicevox_core_macros/src/lib.rs index fee3467c7..e12cdc27e 100644 --- a/crates/voicevox_core_macros/src/lib.rs +++ b/crates/voicevox_core_macros/src/lib.rs @@ -4,18 +4,78 @@ mod inference_domain; use syn::parse_macro_input; +/// `voicevox_core`内で、`crate::infer::InferenceDomain`を実装する。 +/// +/// # Example +/// +/// ``` +/// use enum_map::Enum; +/// use macros::InferenceDomain; +/// +/// #[derive(Clone, Copy, Enum, InferenceDomain)] +/// pub(crate) enum InferenceKind { +/// #[inference_domain( +/// type Input = PredictDurationInput; +/// type Output = PredictDurationOutput; +/// )] +/// PredictDuration, +/// +/// #[inference_domain( +/// type Input = PredictIntonationInput; +/// type Output = PredictIntonationOutput; +/// )] +/// PredictIntonation, +/// +/// #[inference_domain( +/// type Input = DecodeInput; +/// type Output = 
DecodeOutput; +/// )] +/// Decode, +/// } +/// ``` +#[cfg(not(doctest))] #[proc_macro_derive(InferenceDomain, attributes(inference_domain))] pub fn derive_inference_domain(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = &parse_macro_input!(input); from_syn(inference_domain::derive_inference_domain(input)) } +/// `voicevox_core`内で、`crate::infer::InferenceInputSignature`を実装する。 +/// +/// # Example +/// +/// ``` +/// use macros::InferenceInputSignature; +/// +/// #[derive(InferenceInputSignature)] +/// #[inference_input_signature( +/// type Signature = PredictDuration; +/// )] +/// pub(crate) struct PredictDurationInput { +/// pub(crate) phoneme_list: Array1, +/// pub(crate) speaker_id: Array1, +/// } +/// ``` +#[cfg(not(doctest))] #[proc_macro_derive(InferenceInputSignature, attributes(inference_input_signature))] pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = &parse_macro_input!(input); from_syn(inference_domain::derive_inference_input_signature(input)) } +/// `voicevox_core`内で、`crate::infer::InferenceInputSignature`を実装する。 +/// +/// # Example +/// +/// ``` +/// use macros::InferenceOutputSignature; +/// +/// #[derive(InferenceOutputSignature)] +/// pub(crate) struct PredictDurationOutput { +/// pub(crate) phoneme_length: Array1, +/// } +/// ``` +#[cfg(not(doctest))] #[proc_macro_derive(InferenceOutputSignature)] pub fn derive_inference_output_signature( input: proc_macro::TokenStream, From 099879375df892544a95825832197b51a3863ba0 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Tue, 14 Nov 2023 23:57:28 +0900 Subject: [PATCH 37/47] =?UTF-8?q?`sealed::InputScalar`=E3=81=ABFIXME?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 3fc6db2cd..c93bbb925 100644 --- 
a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -164,6 +164,8 @@ pub(crate) enum ExtractError { #[error("不正なモデルファイルです")] pub(crate) struct DecryptModelError; +// FIXME: `onnxruntime::TypeToTensorElementDataType`に依存する代わりに、`InputScalar`から`runtimes` +// まではvisitor patternでつなぐ mod sealed { pub(crate) trait InputScalar: OnnxruntimeInputScalar {} From 75fd7acaa6a4d03f67552ca3d80e0253da88b8fd Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Wed, 15 Nov 2023 02:27:19 +0900 Subject: [PATCH 38/47] =?UTF-8?q?"Domain"=E3=81=A8"Operation"=E3=81=AB?= =?UTF-8?q?=E5=88=86=E9=9B=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer.rs | 8 ++- crates/voicevox_core/src/infer/domain.rs | 25 ++++++--- crates/voicevox_core/src/infer/status.rs | 53 ++++++++++--------- crates/voicevox_core/src/inference_core.rs | 13 ++--- crates/voicevox_core/src/voice_model.rs | 4 +- .../src/inference_domain.rs | 41 +++++++++++--- crates/voicevox_core_macros/src/lib.rs | 25 +++++---- 7 files changed, 111 insertions(+), 58 deletions(-) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index c93bbb925..e3698935e 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -36,7 +36,11 @@ pub(crate) trait InferenceRuntime: 'static { fn run(ctx: Self::RunContext<'_>) -> anyhow::Result>; } -pub(crate) trait InferenceDomain: Copy + Enum { +pub(crate) trait InferenceDomain { + type Operation: InferenceOperation; +} + +pub(crate) trait InferenceOperation: Copy + Enum { /// `{InferenceInputSignature,InferenceOutputSignature}::PARAM_INFOS`を集めたもの。 /// /// マクロ(voicevox_core_macros)で実装される前提。 @@ -54,7 +58,7 @@ pub(crate) trait InferenceSignature: Sized + Send + 'static { type Domain: InferenceDomain; type Input: InferenceInputSignature; type Output: InferenceOutputSignature; - const KIND: Self::Domain; + const OPERATION: ::Operation; } pub(crate) 
trait InferenceInputSignature: Send + 'static { diff --git a/crates/voicevox_core/src/infer/domain.rs b/crates/voicevox_core/src/infer/domain.rs index 1b18fdcbd..1decca75e 100644 --- a/crates/voicevox_core/src/infer/domain.rs +++ b/crates/voicevox_core/src/infer/domain.rs @@ -1,24 +1,35 @@ use enum_map::Enum; -use macros::{InferenceDomain, InferenceInputSignature, InferenceOutputSignature}; +use macros::{InferenceInputSignature, InferenceOperation, InferenceOutputSignature}; use ndarray::{Array0, Array1, Array2}; -use super::{InferenceInputSignature as _, InferenceOutputSignature as _, OutputTensor}; +use super::{ + InferenceDomain, InferenceInputSignature as _, InferenceOutputSignature as _, OutputTensor, +}; -#[derive(Clone, Copy, Enum, InferenceDomain)] -pub(crate) enum InferenceKind { - #[inference_domain( +pub(crate) enum InferenceDomainImpl {} + +impl InferenceDomain for InferenceDomainImpl { + type Operation = InferenceOperationKind; +} + +#[derive(Clone, Copy, Enum, InferenceOperation)] +#[inference_operation( + type Domain = InferenceDomainImpl; +)] +pub(crate) enum InferenceOperationKind { + #[inference_operation( type Input = PredictDurationInput; type Output = PredictDurationOutput; )] PredictDuration, - #[inference_domain( + #[inference_operation( type Input = PredictIntonationInput; type Output = PredictIntonationOutput; )] PredictIntonation, - #[inference_domain( + #[inference_operation( type Input = DecodeInput; type Output = DecodeOutput; )] diff --git a/crates/voicevox_core/src/infer/status.rs b/crates/voicevox_core/src/infer/status.rs index 5740c719a..72e98bd9f 100644 --- a/crates/voicevox_core/src/infer/status.rs +++ b/crates/voicevox_core/src/infer/status.rs @@ -7,12 +7,12 @@ use std::{ use anyhow::bail; use educe::Educe; -use enum_map::EnumMap; +use enum_map::{Enum as _, EnumMap}; use itertools::{iproduct, Itertools as _}; use crate::{ error::{ErrorRepr, LoadModelError, LoadModelErrorKind, LoadModelResult}, - infer::ParamInfo, + 
infer::{InferenceOperation, ParamInfo}, manifest::ModelInnerId, metas::{SpeakerMeta, StyleId, StyleMeta, VoiceModelMeta}, voice_model::{VoiceModel, VoiceModelId}, @@ -26,11 +26,11 @@ use super::{ pub(crate) struct Status { loaded_models: std::sync::Mutex>, - session_options: EnumMap, + session_options: EnumMap, } impl Status { - pub fn new(session_options: EnumMap) -> Self { + pub fn new(session_options: EnumMap) -> Self { Self { loaded_models: Default::default(), session_options, @@ -40,7 +40,7 @@ impl Status { pub async fn load_model( &self, model: &VoiceModel, - model_bytes: &EnumMap>, + model_bytes: &EnumMap>, ) -> Result<()> { self.loaded_models .lock() @@ -241,30 +241,31 @@ impl LoadedModels { } struct SessionSet( - EnumMap>>, + EnumMap>>, ); impl SessionSet { fn new( - model_bytes: &EnumMap>, - options: &EnumMap, + model_bytes: &EnumMap>, + options: &EnumMap, ) -> anyhow::Result { let mut sessions = model_bytes .iter() - .map(|(k, m)| { - let (expected_input_param_infos, expected_output_param_infos) = D::PARAM_INFOS[k]; + .map(|(op, model_bytes)| { + let (expected_input_param_infos, expected_output_param_infos) = + ::PARAM_INFOS[op]; let (sess, actual_input_param_infos, actual_output_param_infos) = - R::new_session(|| model_file::decrypt(m), options[k])?; + R::new_session(|| model_file::decrypt(model_bytes), options[op])?; check_param_infos(expected_input_param_infos, &actual_input_param_infos)?; check_param_infos(expected_output_param_infos, &actual_output_param_infos)?; - Ok((k.into_usize(), std::sync::Mutex::new(sess).into())) + Ok((op.into_usize(), std::sync::Mutex::new(sess).into())) }) .collect::>>()?; - return Ok(Self(EnumMap::::from_fn(|k| { + return Ok(Self(EnumMap::::from_fn(|k| { sessions.remove(&k.into_usize()).expect("should exist") }))); @@ -305,7 +306,7 @@ impl SessionSet { I::Signature: InferenceSignature, { SessionCell { - inner: self.0[I::Signature::KIND].clone(), + inner: self.0[I::Signature::OPERATION].clone(), marker: PhantomData, } } @@ 
-333,8 +334,10 @@ mod tests { use rstest::rstest; use crate::{ - infer::domain::InferenceKind, macros::tests::assert_debug_fmt_eq, - synthesizer::InferenceRuntimeImpl, test_util::open_default_vvm_file, + infer::domain::{InferenceDomainImpl, InferenceOperationKind}, + macros::tests::assert_debug_fmt_eq, + synthesizer::InferenceRuntimeImpl, + test_util::open_default_vvm_file, }; use super::{super::InferenceSessionOptions, Status}; @@ -351,23 +354,23 @@ mod tests { let light_session_options = InferenceSessionOptions::new(cpu_num_threads, false); let heavy_session_options = InferenceSessionOptions::new(cpu_num_threads, use_gpu); let session_options = enum_map! { - InferenceKind::PredictDuration - | InferenceKind::PredictIntonation => light_session_options, - InferenceKind::Decode => heavy_session_options, + InferenceOperationKind::PredictDuration + | InferenceOperationKind::PredictIntonation => light_session_options, + InferenceOperationKind::Decode => heavy_session_options, }; - let status = Status::::new(session_options); + let status = Status::::new(session_options); assert_eq!( light_session_options, - status.session_options[InferenceKind::PredictDuration], + status.session_options[InferenceOperationKind::PredictDuration], ); assert_eq!( light_session_options, - status.session_options[InferenceKind::PredictIntonation], + status.session_options[InferenceOperationKind::PredictIntonation], ); assert_eq!( heavy_session_options, - status.session_options[InferenceKind::Decode], + status.session_options[InferenceOperationKind::Decode], ); assert!(status.loaded_models.lock().unwrap().0.is_empty()); @@ -376,7 +379,7 @@ mod tests { #[rstest] #[tokio::test] async fn status_load_model_works() { - let status = Status::::new( + let status = Status::::new( enum_map!(_ => InferenceSessionOptions::new(0, false)), ); let model = &open_default_vvm_file().await; @@ -389,7 +392,7 @@ mod tests { #[rstest] #[tokio::test] async fn status_is_model_loaded_works() { - let status = 
Status::::new( + let status = Status::::new( enum_map!(_ => InferenceSessionOptions::new(0, false)), ); let vvm = open_default_vvm_file().await; diff --git a/crates/voicevox_core/src/inference_core.rs b/crates/voicevox_core/src/inference_core.rs index 71628ef96..318983dfb 100644 --- a/crates/voicevox_core/src/inference_core.rs +++ b/crates/voicevox_core/src/inference_core.rs @@ -2,8 +2,9 @@ use enum_map::enum_map; use crate::infer::{ domain::{ - DecodeInput, DecodeOutput, InferenceKind, PredictDurationInput, PredictDurationOutput, - PredictIntonationInput, PredictIntonationOutput, + DecodeInput, DecodeOutput, InferenceDomainImpl, InferenceOperationKind, + PredictDurationInput, PredictDurationOutput, PredictIntonationInput, + PredictIntonationOutput, }, status::Status, InferenceRuntime, InferenceSessionOptions, @@ -14,7 +15,7 @@ use super::*; const PHONEME_LENGTH_MINIMAL: f32 = 0.01; pub(crate) struct InferenceCore { - status: Status, + status: Status, } impl InferenceCore { @@ -27,9 +28,9 @@ impl InferenceCore { let heavy_session_options = InferenceSessionOptions::new(cpu_num_threads, use_gpu); let status = Status::new(enum_map! 
{ - InferenceKind::PredictDuration - | InferenceKind::PredictIntonation => light_session_options, - InferenceKind::Decode => heavy_session_options, + InferenceOperationKind::PredictDuration + | InferenceOperationKind::PredictIntonation => light_session_options, + InferenceOperationKind::Decode => heavy_session_options, }); Ok(Self { status }) } else { diff --git a/crates/voicevox_core/src/voice_model.rs b/crates/voicevox_core/src/voice_model.rs index 96494eab5..4a165261b 100644 --- a/crates/voicevox_core/src/voice_model.rs +++ b/crates/voicevox_core/src/voice_model.rs @@ -4,7 +4,7 @@ use futures::future::join3; use serde::{de::DeserializeOwned, Deserialize}; use super::*; -use crate::infer::domain::InferenceKind; +use crate::infer::domain::InferenceOperationKind; use std::{ collections::{BTreeMap, HashMap}, io, @@ -40,7 +40,7 @@ pub struct VoiceModel { impl VoiceModel { pub(crate) async fn read_inference_models( &self, - ) -> LoadModelResult>> { + ) -> LoadModelResult>> { let reader = VvmEntryReader::open(&self.path).await?; let (decode_model_result, predict_duration_model_result, predict_intonation_model_result) = join3( diff --git a/crates/voicevox_core_macros/src/inference_domain.rs b/crates/voicevox_core_macros/src/inference_domain.rs index 69bc809a5..be5075898 100644 --- a/crates/voicevox_core_macros/src/inference_domain.rs +++ b/crates/voicevox_core_macros/src/inference_domain.rs @@ -7,12 +7,13 @@ use syn::{ ItemType, Type, Variant, }; -pub(crate) fn derive_inference_domain( +pub(crate) fn derive_inference_operation( input: &DeriveInput, ) -> syn::Result { let DeriveInput { + attrs, vis, - ident: domain_name, + ident: operation_ty_name, generics, data, .. 
@@ -20,16 +21,27 @@ pub(crate) fn derive_inference_domain( deny_generics(generics)?; + let AssocTypeDomain(domain_ty) = attrs + .iter() + .find(|a| a.path().is_ident("inference_operation")) + .ok_or_else(|| { + syn::Error::new( + proc_macro2::Span::call_site(), + "missing `#[inference_operation(…)]`", + ) + })? + .parse_args()?; + let variants = unit_enum_variants(data)? .into_iter() .map(|(attrs, variant_name)| { let AssocTypes { input, output } = attrs .iter() - .find(|a| a.path().is_ident("inference_domain")) + .find(|a| a.path().is_ident("inference_operation")) .ok_or_else(|| { syn::Error::new( proc_macro2::Span::call_site(), - "missing `#[inference_domain(…)]`", + "missing `#[inference_operation(…)]`", ) })? .parse_args()?; @@ -47,16 +59,18 @@ pub(crate) fn derive_inference_domain( #vis enum #variant_name {} impl crate::infer::InferenceSignature for #variant_name { - type Domain = #domain_name; + type Domain = #domain_ty; type Input = #input_ty; type Output = #output_ty; - const KIND: Self::Domain = #domain_name :: #variant_name; + + const OPERATION: ::Operation = + #operation_ty_name :: #variant_name; } } }); return Ok(quote! { - impl crate::infer::InferenceDomain for #domain_name { + impl crate::infer::InferenceOperation for #operation_ty_name { const PARAM_INFOS: ::enum_map::EnumMap< Self, ( @@ -74,6 +88,19 @@ pub(crate) fn derive_inference_domain( #(#signatures)* }); + struct AssocTypeDomain(Type); + + impl Parse for AssocTypeDomain { + fn parse(input: ParseStream<'_>) -> syn::Result { + let ItemType { ident, ty, .. 
} = input.parse()?; + + if ident != "Domain" { + return Err(syn::Error::new(ident.span(), "expected `Domain`")); + } + Ok(Self(*ty)) + } + } + struct AssocTypes { input: Type, output: Type, diff --git a/crates/voicevox_core_macros/src/lib.rs b/crates/voicevox_core_macros/src/lib.rs index e12cdc27e..d637033a9 100644 --- a/crates/voicevox_core_macros/src/lib.rs +++ b/crates/voicevox_core_macros/src/lib.rs @@ -10,23 +10,30 @@ use syn::parse_macro_input; /// /// ``` /// use enum_map::Enum; -/// use macros::InferenceDomain; +/// use macros::InferenceOperation; /// -/// #[derive(Clone, Copy, Enum, InferenceDomain)] -/// pub(crate) enum InferenceKind { -/// #[inference_domain( +/// impl InferenceDomain for InferenceDomainImpl { +/// type Operation = InferenceOperationKind; +/// } +/// +/// #[derive(Clone, Copy, Enum, InferenceOperation)] +/// #[inference_operation( +/// type Domain = InferenceDomainImpl; +/// )] +/// pub(crate) enum InferenceOperationKind { +/// #[inference_operation( /// type Input = PredictDurationInput; /// type Output = PredictDurationOutput; /// )] /// PredictDuration, /// -/// #[inference_domain( +/// #[inference_operation( /// type Input = PredictIntonationInput; /// type Output = PredictIntonationOutput; /// )] /// PredictIntonation, /// -/// #[inference_domain( +/// #[inference_operation( /// type Input = DecodeInput; /// type Output = DecodeOutput; /// )] @@ -34,10 +41,10 @@ use syn::parse_macro_input; /// } /// ``` #[cfg(not(doctest))] -#[proc_macro_derive(InferenceDomain, attributes(inference_domain))] -pub fn derive_inference_domain(input: proc_macro::TokenStream) -> proc_macro::TokenStream { +#[proc_macro_derive(InferenceOperation, attributes(inference_operation))] +pub fn derive_inference_operation(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = &parse_macro_input!(input); - from_syn(inference_domain::derive_inference_domain(input)) + from_syn(inference_domain::derive_inference_operation(input)) } /// 
`voicevox_core`内で、`crate::infer::InferenceInputSignature`を実装する。 From ad222c98cd1e47a4dbb920d742bf2844bf3cd442 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Thu, 16 Nov 2023 01:03:14 +0900 Subject: [PATCH 39/47] =?UTF-8?q?`InferenceOperationKind`=20=E2=86=92=20`I?= =?UTF-8?q?nferenceOperationImpl`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer/domain.rs | 4 ++-- crates/voicevox_core/src/infer/status.rs | 14 +++++++------- crates/voicevox_core/src/inference_core.rs | 8 ++++---- crates/voicevox_core/src/voice_model.rs | 4 ++-- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/crates/voicevox_core/src/infer/domain.rs b/crates/voicevox_core/src/infer/domain.rs index 1decca75e..bb83886dd 100644 --- a/crates/voicevox_core/src/infer/domain.rs +++ b/crates/voicevox_core/src/infer/domain.rs @@ -9,14 +9,14 @@ use super::{ pub(crate) enum InferenceDomainImpl {} impl InferenceDomain for InferenceDomainImpl { - type Operation = InferenceOperationKind; + type Operation = InferenceOperationImpl; } #[derive(Clone, Copy, Enum, InferenceOperation)] #[inference_operation( type Domain = InferenceDomainImpl; )] -pub(crate) enum InferenceOperationKind { +pub(crate) enum InferenceOperationImpl { #[inference_operation( type Input = PredictDurationInput; type Output = PredictDurationOutput; diff --git a/crates/voicevox_core/src/infer/status.rs b/crates/voicevox_core/src/infer/status.rs index 72e98bd9f..7903cb8ff 100644 --- a/crates/voicevox_core/src/infer/status.rs +++ b/crates/voicevox_core/src/infer/status.rs @@ -334,7 +334,7 @@ mod tests { use rstest::rstest; use crate::{ - infer::domain::{InferenceDomainImpl, InferenceOperationKind}, + infer::domain::{InferenceDomainImpl, InferenceOperationImpl}, macros::tests::assert_debug_fmt_eq, synthesizer::InferenceRuntimeImpl, test_util::open_default_vvm_file, @@ -354,23 +354,23 @@ mod tests { let light_session_options = 
InferenceSessionOptions::new(cpu_num_threads, false); let heavy_session_options = InferenceSessionOptions::new(cpu_num_threads, use_gpu); let session_options = enum_map! { - InferenceOperationKind::PredictDuration - | InferenceOperationKind::PredictIntonation => light_session_options, - InferenceOperationKind::Decode => heavy_session_options, + InferenceOperationImpl::PredictDuration + | InferenceOperationImpl::PredictIntonation => light_session_options, + InferenceOperationImpl::Decode => heavy_session_options, }; let status = Status::::new(session_options); assert_eq!( light_session_options, - status.session_options[InferenceOperationKind::PredictDuration], + status.session_options[InferenceOperationImpl::PredictDuration], ); assert_eq!( light_session_options, - status.session_options[InferenceOperationKind::PredictIntonation], + status.session_options[InferenceOperationImpl::PredictIntonation], ); assert_eq!( heavy_session_options, - status.session_options[InferenceOperationKind::Decode], + status.session_options[InferenceOperationImpl::Decode], ); assert!(status.loaded_models.lock().unwrap().0.is_empty()); diff --git a/crates/voicevox_core/src/inference_core.rs b/crates/voicevox_core/src/inference_core.rs index 318983dfb..875c9ba64 100644 --- a/crates/voicevox_core/src/inference_core.rs +++ b/crates/voicevox_core/src/inference_core.rs @@ -2,7 +2,7 @@ use enum_map::enum_map; use crate::infer::{ domain::{ - DecodeInput, DecodeOutput, InferenceDomainImpl, InferenceOperationKind, + DecodeInput, DecodeOutput, InferenceDomainImpl, InferenceOperationImpl, PredictDurationInput, PredictDurationOutput, PredictIntonationInput, PredictIntonationOutput, }, @@ -28,9 +28,9 @@ impl InferenceCore { let heavy_session_options = InferenceSessionOptions::new(cpu_num_threads, use_gpu); let status = Status::new(enum_map! 
{ - InferenceOperationKind::PredictDuration - | InferenceOperationKind::PredictIntonation => light_session_options, - InferenceOperationKind::Decode => heavy_session_options, + InferenceOperationImpl::PredictDuration + | InferenceOperationImpl::PredictIntonation => light_session_options, + InferenceOperationImpl::Decode => heavy_session_options, }); Ok(Self { status }) } else { diff --git a/crates/voicevox_core/src/voice_model.rs b/crates/voicevox_core/src/voice_model.rs index 4a165261b..829bbf43d 100644 --- a/crates/voicevox_core/src/voice_model.rs +++ b/crates/voicevox_core/src/voice_model.rs @@ -4,7 +4,7 @@ use futures::future::join3; use serde::{de::DeserializeOwned, Deserialize}; use super::*; -use crate::infer::domain::InferenceOperationKind; +use crate::infer::domain::InferenceOperationImpl; use std::{ collections::{BTreeMap, HashMap}, io, @@ -40,7 +40,7 @@ pub struct VoiceModel { impl VoiceModel { pub(crate) async fn read_inference_models( &self, - ) -> LoadModelResult>> { + ) -> LoadModelResult>> { let reader = VvmEntryReader::open(&self.path).await?; let (decode_model_result, predict_duration_model_result, predict_intonation_model_result) = join3( From 7005c96a74cb49789eaf07d7e2411163fca09bbc Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Thu, 16 Nov 2023 08:33:07 +0900 Subject: [PATCH 40/47] =?UTF-8?q?doc=E3=82=92=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core_macros/src/lib.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/crates/voicevox_core_macros/src/lib.rs b/crates/voicevox_core_macros/src/lib.rs index d637033a9..93abbc4da 100644 --- a/crates/voicevox_core_macros/src/lib.rs +++ b/crates/voicevox_core_macros/src/lib.rs @@ -12,15 +12,17 @@ use syn::parse_macro_input; /// use enum_map::Enum; /// use macros::InferenceOperation; /// +/// pub(crate) enum InferenceDomainImpl {} +/// /// impl InferenceDomain for InferenceDomainImpl { -/// 
type Operation = InferenceOperationKind; +/// type Operation = InferenceOperationImpl; /// } /// /// #[derive(Clone, Copy, Enum, InferenceOperation)] /// #[inference_operation( /// type Domain = InferenceDomainImpl; /// )] -/// pub(crate) enum InferenceOperationKind { +/// pub(crate) enum InferenceOperationImpl { /// #[inference_operation( /// type Input = PredictDurationInput; /// type Output = PredictDurationOutput; From 1655719e178680e66b1ea6baa2fa3150f42949c5 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Thu, 16 Nov 2023 08:36:31 +0900 Subject: [PATCH 41/47] =?UTF-8?q?"voicevox=5Fcore=E5=86=85=E3=81=A7"=20?= =?UTF-8?q?=E2=86=92=20"Rust=20API=E3=82=AF=E3=83=AC=E3=83=BC=E3=83=88?= =?UTF-8?q?=E5=86=85=E3=81=A7"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core_macros/src/lib.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/voicevox_core_macros/src/lib.rs b/crates/voicevox_core_macros/src/lib.rs index 93abbc4da..a197475da 100644 --- a/crates/voicevox_core_macros/src/lib.rs +++ b/crates/voicevox_core_macros/src/lib.rs @@ -4,7 +4,7 @@ mod inference_domain; use syn::parse_macro_input; -/// `voicevox_core`内で、`crate::infer::InferenceDomain`を実装する。 +/// Rust APIクレート内で、`crate::infer::InferenceDomain`を実装する。 /// /// # Example /// @@ -49,7 +49,7 @@ pub fn derive_inference_operation(input: proc_macro::TokenStream) -> proc_macro: from_syn(inference_domain::derive_inference_operation(input)) } -/// `voicevox_core`内で、`crate::infer::InferenceInputSignature`を実装する。 +/// Rust APIクレート内で、`crate::infer::InferenceInputSignature`を実装する。 /// /// # Example /// @@ -72,7 +72,7 @@ pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_ from_syn(inference_domain::derive_inference_input_signature(input)) } -/// `voicevox_core`内で、`crate::infer::InferenceInputSignature`を実装する。 +/// Rust APIクレート内で、`crate::infer::InferenceInputSignature`を実装する。 /// /// # Example /// From 
a73f22c3a4dc76ea93d3af6b9003d71fc25659ae Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Thu, 16 Nov 2023 08:53:32 +0900 Subject: [PATCH 42/47] =?UTF-8?q?doc=E3=82=92=E8=BF=BD=E8=A8=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer.rs | 15 +++++++++++++++ crates/voicevox_core_macros/src/lib.rs | 12 +++++++++--- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index e3698935e..4edef17f8 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -36,10 +36,16 @@ pub(crate) trait InferenceRuntime: 'static { fn run(ctx: Self::RunContext<'_>) -> anyhow::Result>; } +/// `VoiceModel`に対応する、 pub(crate) trait InferenceDomain { type Operation: InferenceOperation; } +/// `InferenceDomain`の推論操作を表す列挙型。 +/// +/// それぞれのバリアントには、対応する`InferenceSignature`が存在するべきである。 +/// +/// `::macros::InferenceOperation`により導出される。 pub(crate) trait InferenceOperation: Copy + Enum { /// `{InferenceInputSignature,InferenceOutputSignature}::PARAM_INFOS`を集めたもの。 /// @@ -54,6 +60,9 @@ pub(crate) trait InferenceOperation: Copy + Enum { >; } +/// `InferenceDomain`の推論操作ひとつ分のシグネチャ(入力と出力の組)。 +/// +/// `::macros::InferenceOperation`により、具体型ごと生成される。 pub(crate) trait InferenceSignature: Sized + Send + 'static { type Domain: InferenceDomain; type Input: InferenceInputSignature; @@ -61,6 +70,9 @@ pub(crate) trait InferenceSignature: Sized + Send + 'static { const OPERATION: ::Operation; } +/// 推論操作の入力シグネチャ。 +/// +/// `::macros::InferenceInputSignature`により導出される。 pub(crate) trait InferenceInputSignature: Send + 'static { type Signature: InferenceSignature; const PARAM_INFOS: &'static [ParamInfo]; @@ -88,6 +100,9 @@ pub(crate) enum InputScalarKind { Float32, } +/// 推論操作の出力シグネチャ。 +/// +/// `::macros::InferenceOutputSignature`により、`TryFrom`も含めて導出される。 pub(crate) trait InferenceOutputSignature: TryFrom, Error = anyhow::Error> + Send { diff
--git a/crates/voicevox_core_macros/src/lib.rs b/crates/voicevox_core_macros/src/lib.rs index a197475da..5f2f26809 100644 --- a/crates/voicevox_core_macros/src/lib.rs +++ b/crates/voicevox_core_macros/src/lib.rs @@ -4,7 +4,12 @@ mod inference_domain; use syn::parse_macro_input; -/// Rust APIクレート内で、`crate::infer::InferenceDomain`を実装する。 +/// Rust APIクレート内で、`crate::infer::InferenceOperation`の導出などを行う。 +/// +/// 次のことを行う。 +/// +/// - `InferenceOperation`の導出 +/// - 各バリアントに対する`InferenceSignature`の実装を、型ごと生成 /// /// # Example /// @@ -49,7 +54,7 @@ pub fn derive_inference_operation(input: proc_macro::TokenStream) -> proc_macro: from_syn(inference_domain::derive_inference_operation(input)) } -/// Rust APIクレート内で、`crate::infer::InferenceInputSignature`を実装する。 +/// Rust APIクレート内で、`crate::infer::InferenceInputSignature`を導出する。 /// /// # Example /// @@ -72,7 +77,8 @@ pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_ from_syn(inference_domain::derive_inference_input_signature(input)) } -/// Rust APIクレート内で、`crate::infer::InferenceInputSignature`を実装する。 +/// Rust APIクレート内で`crate::infer::InferenceOutputSignature`を、`TryFrom`ごと導出 +/// する。 /// /// # Example /// From f17919b694ac29cbea4a8e3caaad900d77feeebd Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Thu, 16 Nov 2023 08:54:40 +0900 Subject: [PATCH 43/47] =?UTF-8?q?doc=E3=82=92=E8=BF=BD=E8=A8=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 4edef17f8..255f07a71 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -102,7 +102,7 @@ pub(crate) enum InputScalarKind { /// 推論操作の出力シグネチャ。 /// -/// `::macros::InferenceOutputSignature`により、`TryFrom`も含めて導出される。 +/// `::macros::InferenceOutputSignature`により、`TryFrom`も含めて導出される。 pub(crate) trait
InferenceOutputSignature: TryFrom, Error = anyhow::Error> + Send { From 48bdb1b74470eca0fd2ce033459dedbb6700d9af Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Thu, 16 Nov 2023 10:03:36 +0900 Subject: [PATCH 44/47] =?UTF-8?q?`InferenceDomain`=E3=81=AEdoc=E3=82=92?= =?UTF-8?q?=E6=9B=B8=E3=81=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 255f07a71..27f04e09b 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -36,7 +36,7 @@ pub(crate) trait InferenceRuntime: 'static { fn run(ctx: Self::RunContext<'_>) -> anyhow::Result>; } -/// `VoiceModel`に対応する、 +/// ある`VoiceModel`が提供する推論操作の集合を示す。 pub(crate) trait InferenceDomain { type Operation: InferenceOperation; } From 9d7d001b81482fd31a29304766c5c4274c82b46d Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Thu, 16 Nov 2023 10:04:02 +0900 Subject: [PATCH 45/47] =?UTF-8?q?=E4=B8=8D=E8=A6=81=E3=81=AA=E6=96=87?= =?UTF-8?q?=E3=81=AE=E5=89=8A=E9=99=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 27f04e09b..2657edbaa 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -48,8 +48,6 @@ pub(crate) trait InferenceDomain { /// `::macros::InferenceOperation`により導出される。 pub(crate) trait InferenceOperation: Copy + Enum { /// `{InferenceInputSignature,InferenceOutputSignature}::PARAM_INFOS`を集めたもの。 - /// - /// マクロ(voicevox_core_macros)で実装される前提。 #[allow(clippy::type_complexity)] const PARAM_INFOS: EnumMap< Self, From af828eb943b1b43542648ac8c7280d3001591c89 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Thu, 16 Nov 2023 10:04:29 
+0900 Subject: [PATCH 46/47] Minor refactor --- crates/voicevox_core/src/infer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 2657edbaa..bfc74b08f 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -65,7 +65,7 @@ pub(crate) trait InferenceSignature: Sized + Send + 'static { type Domain: InferenceDomain; type Input: InferenceInputSignature; type Output: InferenceOutputSignature; - const OPERATION: ::Operation; + const OPERATION: ::Operation; } /// 推論操作の入力シグネチャ。 From b6b7975278b396ad6bbc225ab98b6185830c3613 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Thu, 16 Nov 2023 10:18:02 +0900 Subject: [PATCH 47/47] =?UTF-8?q?`ArrayExt`=E3=82=92=E3=83=9E=E3=82=AF?= =?UTF-8?q?=E3=83=AD=E5=86=85=E3=81=AB=E6=8A=BC=E3=81=97=E8=BE=BC=E3=82=81?= =?UTF-8?q?=E3=82=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer.rs | 10 ---- .../src/inference_domain.rs | 52 ++++++++++++------- 2 files changed, 34 insertions(+), 28 deletions(-) diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index bfc74b08f..c6b81348a 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -155,16 +155,6 @@ impl ParamInfo { } } -pub(crate) trait ArrayExt { - type Scalar; - type Dimension: Dimension; -} - -impl ArrayExt for Array { - type Scalar = A; - type Dimension = D; -} - #[derive(new, Clone, Copy, PartialEq, Debug)] pub(crate) struct InferenceSessionOptions { pub(crate) cpu_num_threads: u16, diff --git a/crates/voicevox_core_macros/src/inference_domain.rs b/crates/voicevox_core_macros/src/inference_domain.rs index be5075898..4a447d37d 100644 --- a/crates/voicevox_core_macros/src/inference_domain.rs +++ b/crates/voicevox_core_macros/src/inference_domain.rs @@ -187,12 +187,8 @@ pub(crate) fn derive_inference_input_signature( 
quote! { crate::infer::ParamInfo { name: ::std::borrow::Cow::Borrowed(#name), - dt: < - <#ty as crate::infer::ArrayExt>::Scalar as crate::infer::InputScalar - >::KIND, - ndim: < - <#ty as crate::infer::ArrayExt>::Dimension as ::ndarray::Dimension - >::NDIM, + dt: <<#ty as __ArrayExt>::Scalar as crate::infer::InputScalar>::KIND, + ndim: <<#ty as __ArrayExt>::Dimension as ::ndarray::Dimension>::NDIM, }, } }) @@ -208,9 +204,21 @@ pub(crate) fn derive_inference_input_signature( const PARAM_INFOS: &'static [crate::infer::ParamInfo< crate::infer::InputScalarKind - >] = &[ - #param_infos - ]; + >] = { + trait __ArrayExt { + type Scalar: crate::infer::InputScalar; + type Dimension: ::ndarray::Dimension + 'static; + } + + impl __ArrayExt + for ::ndarray::Array + { + type Scalar = A; + type Dimension = D; + } + + &[#param_infos] + }; fn make_run_context( self, @@ -261,12 +269,8 @@ pub(crate) fn derive_inference_output_signature( quote! { crate::infer::ParamInfo { name: ::std::borrow::Cow::Borrowed(#name), - dt: < - <#ty as crate::infer::ArrayExt>::Scalar as crate::infer::OutputScalar - >::KIND, - ndim: < - <#ty as crate::infer::ArrayExt>::Dimension as ::ndarray::Dimension - >::NDIM, + dt: <<#ty as __ArrayExt>::Scalar as crate::infer::OutputScalar>::KIND, + ndim: <<#ty as __ArrayExt>::Dimension as ::ndarray::Dimension>::NDIM, }, } }) @@ -280,9 +284,21 @@ pub(crate) fn derive_inference_output_signature( { const PARAM_INFOS: &'static [crate::infer::ParamInfo< crate::infer::OutputScalarKind - >] = &[ - #param_infos - ]; + >] = { + trait __ArrayExt { + type Scalar: crate::infer::OutputScalar; + type Dimension: ::ndarray::Dimension + 'static; + } + + impl __ArrayExt + for ::ndarray::Array + { + type Scalar = A; + type Dimension = D; + } + + &[#param_infos] + }; } impl #impl_generics ::std::convert::TryFrom<::std::vec::Vec>