Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ONNX Runtimeとモデルのシグネチャを隔離する #675

Merged
merged 48 commits into main from split-onnxruntime-and-model-signatures
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
5ff2b59
ONNX Runtimeとモデルのシグネチャを隔離する
qryxip Nov 5, 2023
33245ff
`R: InferenceCore`を`SynthesisEngine`まで持っていく
qryxip Nov 5, 2023
4bd2281
`SupportedDevices::create`の実装を移動
qryxip Nov 5, 2023
f959911
不要なreexportを削除
qryxip Nov 5, 2023
55b04d3
`InferenceModels`の定義を`signatures`に移動
qryxip Nov 5, 2023
c38ad99
`ErrorRepr::GetSupportedDevices`の中身を`anyhow::Error`に
qryxip Nov 6, 2023
192417f
enum-map v3.0.0-beta.1を導入し、`EnumMap`駆動に
qryxip Nov 6, 2023
20db67a
Minor refactor
qryxip Nov 6, 2023
cc84068
Minor refactor
qryxip Nov 6, 2023
e0f29c6
色々再構成
qryxip Nov 8, 2023
cb1db34
Fix up
qryxip Nov 8, 2023
c3e08dd
`OnnxruntimeInferenceBuilder` → `OnnxruntimeRunContext`
qryxip Nov 8, 2023
e4b91ab
`impl SupportsInferenceOutput<_> for Onnxruntime`を拡張
qryxip Nov 8, 2023
8584d27
`SignatureKind` → `InferenceSignatureKind`
qryxip Nov 8, 2023
4795309
`LoadedModels`へのアクセスをメソッド越しにするのを徹底する
qryxip Nov 8, 2023
a5dbbdd
Minor refactor
qryxip Nov 8, 2023
525f4b1
`InferenceInput` → `InferenceInputSignature`
qryxip Nov 8, 2023
26476f5
相互参照
qryxip Nov 8, 2023
fbd7d1c
`fn input`まわりを明瞭にする
qryxip Nov 9, 2023
8b4f3b6
"signature"のkindではなく"model"のkindとする
qryxip Nov 9, 2023
c40afd5
"model"ではなく"inference"と呼ぶ
qryxip Nov 11, 2023
81b5804
ランタイムは任意次元任意個数の入出力ができると仮定する
qryxip Nov 11, 2023
120106b
voicevox_core_macrosを作り、"signatures"の実装をマクロ化
qryxip Nov 11, 2023
590ce48
`AnyTensor` → `OutputTensor`
qryxip Nov 11, 2023
c4d5ebe
`INFERENCE` → `KIND`
qryxip Nov 11, 2023
1b1b7bf
`status`を`infer`下に
qryxip Nov 11, 2023
c39f48c
`trait RunContext`を削除
qryxip Nov 11, 2023
c316209
"kind"を直接"group"と呼ぶことにする
qryxip Nov 11, 2023
2274a34
シグネチャの実行時チェック機構を入れる
qryxip Nov 12, 2023
b7d48f3
signaturesのマクロ化を完了させる
qryxip Nov 13, 2023
b6db1c0
Minor refactor
qryxip Nov 13, 2023
96a93e9
`InferenceGroup` → `InferenceDomain`
qryxip Nov 14, 2023
59d8779
Minor refactor
qryxip Nov 14, 2023
d0dc56f
`InferenceDomain::{INPUT,OUTPUT}_PARAM_INFOS`を統合
qryxip Nov 14, 2023
c654cd1
`InferenceDomain::PARAM_INFOS`にdocstring
qryxip Nov 14, 2023
868d3f6
voicevox_core_macrosにdocstring
qryxip Nov 14, 2023
0998793
`sealed::InputScalar`にFIXME
qryxip Nov 14, 2023
75fd7ac
"Domain"と"Operation"に分離
qryxip Nov 14, 2023
ad222c9
`InferenceOperationKind` → `InferenceOperationImpl`
qryxip Nov 15, 2023
7005c96
docを修正
qryxip Nov 15, 2023
9417992
Merge branch 'main' into split-onnxruntime-and-model-signatures
qryxip Nov 15, 2023
1655719
"voicevox_core内で" → "Rust APIクレート内で"
qryxip Nov 15, 2023
a73f22c
docを追記
qryxip Nov 15, 2023
f17919b
docを追記
qryxip Nov 15, 2023
48bdb1b
`InferenceDomain`のdocを書く
qryxip Nov 16, 2023
9d7d001
不要な文の削除
qryxip Nov 16, 2023
af828eb
Minor refactor
qryxip Nov 16, 2023
b6b7975
`ArrayExt`をマクロ内に押し込める
qryxip Nov 16, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 3 additions & 9 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,5 @@
[workspace]
members = [
"crates/downloader",
"crates/test_util",
"crates/voicevox_core",
"crates/voicevox_core_c_api",
"crates/voicevox_core_java_api",
"crates/voicevox_core_python_api",
"crates/xtask"
]
members = ["crates/*"]
resolver = "2"

[workspace.dependencies]
Expand All @@ -18,7 +10,9 @@ derive_more = "0.99.17"
easy-ext = "1.0.1"
fs-err = { version = "2.9.0", features = ["tokio"] }
futures = "0.3.26"
indexmap = { version = "2.0.0", features = ["serde"] }
itertools = "0.10.5"
ndarray = "0.15.6"
once_cell = "1.18.0"
regex = "1.10.0"
rstest = "0.15.0"
Expand Down
6 changes: 5 additions & 1 deletion crates/voicevox_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@ derive-new = "0.5.9"
derive_more.workspace = true
duplicate = "1.0.0"
easy-ext.workspace = true
educe = "0.4.23"
enum-map = "3.0.0-beta.1"
fs-err.workspace = true
futures.workspace = true
indexmap = { version = "2.0.0", features = ["serde"] }
indexmap.workspace = true
itertools.workspace = true
nanoid = "0.4.0"
ndarray.workspace = true
once_cell.workspace = true
regex.workspace = true
serde.workspace = true
Expand All @@ -31,6 +34,7 @@ thiserror.workspace = true
tokio.workspace = true
tracing.workspace = true
uuid.workspace = true
voicevox_core_macros = { path = "../voicevox_core_macros" }

[dependencies.onnxruntime]
git = "https://github.com/VOICEVOX/onnxruntime-rs.git"
Expand Down
26 changes: 5 additions & 21 deletions crates/voicevox_core/src/devices.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use serde::{Deserialize, Serialize};

use super::*;
use crate::{infer::InferenceRuntime, synthesizer::InferenceRuntimeImpl};

/// このライブラリで利用可能なデバイスの情報。
///
Expand All @@ -11,21 +12,21 @@ pub struct SupportedDevices {
/// CPUが利用可能。
///
/// 常に`true`。
cpu: bool,
pub cpu: bool,
/// CUDAが利用可能。
///
/// ONNX Runtimeの[CUDA Execution Provider] (`CUDAExecutionProvider`)に対応する。必要な環境につ
/// いてはそちらを参照。
///
/// [CUDA Execution Provider]: https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html
cuda: bool,
pub cuda: bool,
/// DirectMLが利用可能。
///
/// ONNX Runtimeの[DirectML Execution Provider] (`DmlExecutionProvider`)に対応する。必要な環境に
/// ついてはそちらを参照。
///
/// [DirectML Execution Provider]: https://onnxruntime.ai/docs/execution-providers/DirectML-ExecutionProvider.html
dml: bool,
pub dml: bool,
}

impl SupportedDevices {
Expand All @@ -42,24 +43,7 @@ impl SupportedDevices {
/// # Result::<_, anyhow::Error>::Ok(())
/// ```
pub fn create() -> Result<Self> {
let mut cuda_support = false;
let mut dml_support = false;
for provider in onnxruntime::session::get_available_providers()
.map_err(ErrorRepr::GetSupportedDevices)?
.iter()
{
match provider.as_str() {
"CUDAExecutionProvider" => cuda_support = true,
"DmlExecutionProvider" => dml_support = true,
_ => {}
}
}

Ok(SupportedDevices {
cpu: true,
cuda: cuda_support,
dml: dml_support,
})
<InferenceRuntimeImpl as InferenceRuntime>::supported_devices()
}

pub fn to_json(&self) -> serde_json::Value {
Expand Down
54 changes: 18 additions & 36 deletions crates/voicevox_core/src/engine/synthesis_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::sync::Arc;
use super::full_context_label::Utterance;
use super::open_jtalk::OpenJtalk;
use super::*;
use crate::infer::InferenceRuntime;
use crate::numerics::F32Ext as _;
use crate::InferenceCore;

Expand All @@ -14,19 +15,16 @@ const MORA_PHONEME_LIST: &[&str] = &[
"a", "i", "u", "e", "o", "N", "A", "I", "U", "E", "O", "cl", "pau",
];

pub const DEFAULT_SAMPLING_RATE: u32 = 24000;

#[derive(new)]
pub struct SynthesisEngine {
inference_core: InferenceCore,
pub(crate) struct SynthesisEngine<R: InferenceRuntime> {
inference_core: InferenceCore<R>,
open_jtalk: Arc<OpenJtalk>,
}

#[allow(unsafe_code)]
unsafe impl Send for SynthesisEngine {}

impl SynthesisEngine {
pub const DEFAULT_SAMPLING_RATE: u32 = 24000;

pub fn inference_core(&self) -> &InferenceCore {
impl<R: InferenceRuntime> SynthesisEngine<R> {
pub fn inference_core(&self) -> &InferenceCore<R> {
&self.inference_core
}

Expand Down Expand Up @@ -123,7 +121,7 @@ impl SynthesisEngine {
accent_phrases: &[AccentPhraseModel],
style_id: StyleId,
) -> Result<Vec<AccentPhraseModel>> {
let (_, phoneme_data_list) = SynthesisEngine::initial_process(accent_phrases);
let (_, phoneme_data_list) = Self::initial_process(accent_phrases);

let (_, _, vowel_indexes_data) = split_mora(&phoneme_data_list);

Expand Down Expand Up @@ -185,36 +183,20 @@ impl SynthesisEngine {
accent_phrases: &[AccentPhraseModel],
style_id: StyleId,
) -> Result<Vec<AccentPhraseModel>> {
let (_, phoneme_data_list) = SynthesisEngine::initial_process(accent_phrases);
let (_, phoneme_data_list) = Self::initial_process(accent_phrases);

let mut base_start_accent_list = vec![0];
let mut base_end_accent_list = vec![0];
let mut base_start_accent_phrase_list = vec![0];
let mut base_end_accent_phrase_list = vec![0];
for accent_phrase in accent_phrases {
let mut accent = usize::from(*accent_phrase.accent() != 1);
SynthesisEngine::create_one_accent_list(
&mut base_start_accent_list,
accent_phrase,
accent as i32,
);
Self::create_one_accent_list(&mut base_start_accent_list, accent_phrase, accent as i32);

accent = *accent_phrase.accent() - 1;
SynthesisEngine::create_one_accent_list(
&mut base_end_accent_list,
accent_phrase,
accent as i32,
);
SynthesisEngine::create_one_accent_list(
&mut base_start_accent_phrase_list,
accent_phrase,
0,
);
SynthesisEngine::create_one_accent_list(
&mut base_end_accent_phrase_list,
accent_phrase,
-1,
);
Self::create_one_accent_list(&mut base_end_accent_list, accent_phrase, accent as i32);
Self::create_one_accent_list(&mut base_start_accent_phrase_list, accent_phrase, 0);
Self::create_one_accent_list(&mut base_end_accent_phrase_list, accent_phrase, -1);
}
base_start_accent_list.push(0);
base_end_accent_list.push(0);
Expand Down Expand Up @@ -328,7 +310,7 @@ impl SynthesisEngine {
query.accent_phrases().clone()
};

let (flatten_moras, phoneme_data_list) = SynthesisEngine::initial_process(&accent_phrases);
let (flatten_moras, phoneme_data_list) = Self::initial_process(&accent_phrases);

let mut phoneme_length_list = vec![pre_phoneme_length];
let mut f0_list = vec![0.];
Expand Down Expand Up @@ -440,7 +422,7 @@ impl SynthesisEngine {
let num_channels: u16 = if output_stereo { 2 } else { 1 };
let bit_depth: u16 = 16;
let repeat_count: u32 =
(output_sampling_rate / Self::DEFAULT_SAMPLING_RATE) * num_channels as u32;
(output_sampling_rate / DEFAULT_SAMPLING_RATE) * num_channels as u32;
let block_size: u16 = bit_depth * num_channels / 8;

let bytes_size = wave.len() as u32 * repeat_count * 2;
Expand Down Expand Up @@ -647,12 +629,12 @@ mod tests {
use ::test_util::OPEN_JTALK_DIC_DIR;
use pretty_assertions::assert_eq;

use crate::*;
use crate::{synthesizer::InferenceRuntimeImpl, *};

#[rstest]
#[tokio::test]
async fn is_openjtalk_dict_loaded_works() {
let core = InferenceCore::new(false, 0).unwrap();
let core = InferenceCore::<InferenceRuntimeImpl>::new(false, 0).unwrap();
let synthesis_engine =
SynthesisEngine::new(core, OpenJtalk::new(OPEN_JTALK_DIC_DIR).unwrap().into());

Expand All @@ -662,7 +644,7 @@ mod tests {
#[rstest]
#[tokio::test]
async fn create_accent_phrases_works() {
let core = InferenceCore::new(false, 0).unwrap();
let core = InferenceCore::<InferenceRuntimeImpl>::new(false, 0).unwrap();

let model = &VoiceModel::sample().await.unwrap();
core.load_model(model).await.unwrap();
Expand Down
3 changes: 1 addition & 2 deletions crates/voicevox_core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use self::engine::{FullContextLabelError, KanaParseError};
use super::*;
//use engine::
use duplicate::duplicate_item;
use onnxruntime::OrtError;
use std::path::PathBuf;
use thiserror::Error;
use uuid::Uuid;
Expand Down Expand Up @@ -65,7 +64,7 @@ pub(crate) enum ErrorRepr {
LoadModel(#[from] LoadModelError),

#[error("サポートされているデバイス情報取得中にエラーが発生しました")]
GetSupportedDevices(#[source] OrtError),
GetSupportedDevices(#[source] anyhow::Error),

#[error(
"`{style_id}`に対するスタイルが見つかりませんでした。音声モデルが読み込まれていないか、読\
Expand Down
Loading
Loading