-
Notifications
You must be signed in to change notification settings - Fork 315
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support Transformers in the Wav2Vec2 Encoder for the ASR Inference #1520
Merged
Merged
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
15ee7da
Support Transformers in the Wav2Vec2 Encoder for the ASR Inference
0901484
code style/format check with flask8 & black
4b28cc6
check isort and update
5258816
change ONEAPI_VERSION to 2023.2.0
f9bfa16
add missing package (librosa) for test_wav2vec2.py
d2ff992
import package path update
4ffaab9
isort library update
7078fd5
update vocab return
eeb92ef
add packages requirement for test_wav2vec2.py
9af8b89
merge test_wav2vec2.py to test_transformers.py for the compatibility
d2b01ed
fix python style format
7e0dcdc
update audio_name for TestWav2Vec2
8403920
change the audio downloading
bf63c95
change the audio downloading
dad4b2c
change the audio downloading
a9674ba
add requests for test requirement
1e6aa47
update audio file downloading
478bde5
update audio file downloading path
11f5ff1
switch audio to the existing one
4370869
remove unnecessary audio downloading
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
#pragma once | ||
|
||
#include "ctranslate2/layers/transformer.h" | ||
|
||
namespace ctranslate2 { | ||
namespace layers { | ||
|
||
class Wav2Vec2Encoder : public Layer { | ||
public: | ||
Wav2Vec2Encoder(const models::Model& model, const std::string& scope); | ||
|
||
void operator()(const StorageView& features, StorageView& output); | ||
|
||
DataType output_type() const override { | ||
return _output_norm.output_type(); | ||
} | ||
|
||
dim_t output_size() const override { | ||
return _output_norm.output_size(); | ||
} | ||
|
||
dim_t input_size() const { | ||
return 1024; | ||
} | ||
|
||
bool is_encoded(const StorageView& features) const { | ||
// Input features shape: [batch_size, input_size, input_time] | ||
// Encoder output shape: [batch_size, input_time // 2, output_size] | ||
// | ||
// input_time is variable so we check that dimension 1 is different than its original value. | ||
|
||
return (features.rank() == 3 | ||
&& features.dim(2) == output_size() | ||
&& features.dim(1) != input_size()); | ||
} | ||
|
||
private: | ||
const ops::GELU _gelu; | ||
// wav2vec2.encoder modules except pos_conv_embed due to groups=16 being not supported | ||
//const ops::Transpose _transpose; | ||
const dim_t _num_heads; | ||
const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers; | ||
const LayerNorm _output_norm; | ||
}; | ||
|
||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
#pragma once | ||
|
||
//#include "ctranslate2/generation.h" | ||
#include "ctranslate2/layers/wav2vec2.h" | ||
#include "ctranslate2/models/model.h" | ||
#include "ctranslate2/replica_pool.h" | ||
|
||
namespace ctranslate2 { | ||
namespace models { | ||
|
||
// Plain option bundle declared alongside the Wav2Vec2 model; every field has
// a default value.
struct Wav2Vec2Options {
  // Upper bound on the generation length.
  size_t max_length{448};

  // Number of top candidates to sample from (0 samples from the full distribution).
  size_t sampling_topk{1};

  // Largest allowed index for the first predicted timestamp.
  size_t max_initial_timestamp_index{50};

  // When true, blank outputs are suppressed at the beginning of the sampling.
  bool suppress_blank{true};

  // Token IDs to suppress. The value -1 stands for the default set of symbols
  // defined in the model config.json file.
  std::vector<int> suppress_tokens{-1};
};
|
||
|
||
class Wav2Vec2Model : public Model { | ||
public: | ||
const Vocabulary& get_vocabulary() const; | ||
size_t current_spec_revision() const override; | ||
bool is_quantizable(const std::string& variable_name) const override; | ||
bool is_linear_weight(const std::string& variable_name) const override; | ||
std::unique_ptr<Model> clone() const override; | ||
|
||
bool use_global_int16_scale() const override { | ||
return false; | ||
} | ||
|
||
protected: | ||
void initialize(ModelReader& model_reader) override; | ||
private: | ||
std::shared_ptr<const Vocabulary> _vocabulary; | ||
}; | ||
|
||
class Wav2Vec2Replica : public ModelReplica { | ||
public: | ||
static std::unique_ptr<Wav2Vec2Replica> create_from_model(const Model& model); | ||
|
||
Wav2Vec2Replica(const std::shared_ptr<const Wav2Vec2Model>& model); | ||
|
||
StorageView encode(StorageView features, const bool to_cpu); | ||
|
||
private: | ||
const std::shared_ptr<const Wav2Vec2Model> _model; | ||
const std::unique_ptr<layers::Wav2Vec2Encoder> _encoder; | ||
|
||
StorageView maybe_encode(StorageView features); | ||
}; | ||
|
||
// Pool of Wav2Vec2Replica instances; reuses the ReplicaPool constructors.
class Wav2Vec2 : public ReplicaPool<Wav2Vec2Replica> {
public:
  using ReplicaPool::ReplicaPool;

  // Asynchronous counterpart of Wav2Vec2Replica::encode: the StorageView
  // result is delivered through the returned future.
  std::future<StorageView> encode(const StorageView& features, const bool to_cpu);
};
|
||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
#include "module.h" | ||
|
||
#include <ctranslate2/models/wav2vec2.h> | ||
|
||
#include "replica_pool.h" | ||
|
||
namespace ctranslate2 { | ||
namespace python { | ||
|
||
// Python-facing helper around the models::Wav2Vec2 pool; reuses the
// ReplicaPoolHelper constructors.
class Wav2Vec2Wrapper : public ReplicaPoolHelper<models::Wav2Vec2> {
public:
  using ReplicaPoolHelper::ReplicaPoolHelper;

  // Submits `features` to the pool and blocks until the result is available.
  StorageView encode(const StorageView& features, const bool to_cpu) {
    auto pending = _pool->encode(features, to_cpu);
    return pending.get();
  }
};
|
||
|
||
// Registers the Python class "Wav2Vec2" on module `m`, backed by
// Wav2Vec2Wrapper: one constructor, read-only properties describing the
// device/worker state, and the `encode` method.
void register_wav2vec2(py::module& m) {
  py::class_<Wav2Vec2Wrapper>(
    m, "Wav2Vec2",
    R"pbdoc(
        Implements the Wav2Vec2 speech recognition model published by Facebook.

        See Also:
           https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec
    )pbdoc")

    // Everything after `device` is keyword-only (py::kw_only).
    .def(py::init<const std::string&,
                  const std::string&,
                  const std::variant<int, std::vector<int>>&,
                  const StringOrMap&,
                  size_t,
                  size_t,
                  long,
                  py::object>(),
         py::arg("model_path"),
         py::arg("device")="cpu",
         py::kw_only(),
         py::arg("device_index")=0,
         py::arg("compute_type")="default",
         py::arg("inter_threads")=1,
         py::arg("intra_threads")=0,
         py::arg("max_queued_batches")=0,
         py::arg("files")=py::none(),
         R"pbdoc(
             Initializes a Wav2Vec2 model from a converted model.

             Arguments:
               model_path: Path to the CTranslate2 model directory.
               device: Device to use (possible values are: cpu, cuda, auto).
               device_index: Device IDs where to place this model on.
               compute_type: Model computation type or a dictionary mapping a device name
                 to the computation type (possible values are: default, auto, int8, int8_float32,
                 int8_float16, int8_bfloat16, int16, float16, bfloat16, float32).
               inter_threads: Number of workers to allow executing multiple batches in parallel.
               intra_threads: Number of OpenMP threads per worker (0 to use a default value).
               max_queued_batches: Maximum numbers of batches in the worker queue (-1 for unlimited,
                 0 for an automatic value). When the queue is full, future requests will block
                 until a free slot is available.
               files: Load model files from the memory. This argument is a dictionary mapping
                 file names to file contents as file-like or bytes objects. If this is set,
                 :obj:`model_path` acts as an identifier for this model.
         )pbdoc")

    .def_property_readonly("device", &Wav2Vec2Wrapper::device,
                           "Device this model is running on.")
    .def_property_readonly("device_index", &Wav2Vec2Wrapper::device_index,
                           "List of device IDs where this model is running on.")
    .def_property_readonly("compute_type", &Wav2Vec2Wrapper::compute_type,
                           "Computation type used by the model.")
    .def_property_readonly("num_workers", &Wav2Vec2Wrapper::num_replicas,
                           "Number of model workers backing this instance.")
    .def_property_readonly("num_queued_batches", &Wav2Vec2Wrapper::num_queued_batches,
                           "Number of batches waiting to be processed.")
    .def_property_readonly("num_active_batches", &Wav2Vec2Wrapper::num_active_batches,
                           "Number of batches waiting to be processed or currently processed.")

    // The GIL is released for the duration of the call (py::gil_scoped_release).
    // NOTE(review): the `features` description below (Mel spectrogram,
    // [batch_size, 80, 3000]) may not match what the Wav2Vec2 encoder actually
    // expects; confirm the real input layout against the encoder
    // implementation and update this docstring accordingly.
    .def("encode", &Wav2Vec2Wrapper::encode,
         py::arg("features"),
         py::arg("to_cpu")=false,
         py::call_guard<py::gil_scoped_release>(),
         R"pbdoc(
             Encodes the input features.

             Arguments:
               features: Mel spectrogram of the audio, as a float array with shape
                 ``[batch_size, 80, 3000]``.
               to_cpu: Copy the encoder output to the CPU before returning the value.

             Returns:
               The encoder output.
         )pbdoc")

    ;
}
|
||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from typing import List, Optional, Tuple | ||
|
||
import numpy as np | ||
|
||
from ctranslate2.specs import common_spec, model_spec, transformer_spec | ||
|
||
|
||
class Wav2Vec2Config(model_spec.ModelConfig):
    """Configuration object for Wav2Vec2 models; it defines no fields of its own."""

    def __init__(self):
        # Intentionally empty: nothing is set and the parent initializer is not called.
        pass
|
||
|
||
class Wav2Vec2Spec(model_spec.LanguageModelSpec):
    """Model specification made of a Wav2Vec2 encoder spec and a linear ``lm_head``."""

    def __init__(self, num_layers, num_heads):
        """Builds the spec.

        Arguments:
          num_layers: Forwarded to :class:`Wav2Vec2EncoderSpec`.
          num_heads: Forwarded to :class:`Wav2Vec2EncoderSpec`.
        """
        super().__init__()
        encoder_spec = Wav2Vec2EncoderSpec(num_layers, num_heads)
        self.encoder = encoder_spec
        self.lm_head = common_spec.LinearSpec()

    @property
    def name(self):
        """Name identifying this specification."""
        return "Wav2Vec2Spec"

    @property
    def revision(self):
        """Revision number of this specification."""
        return 3

    def get_default_config(self):
        """Returns a fresh :class:`Wav2Vec2Config`."""
        return Wav2Vec2Config()

    def get_vocabulary_size(self):
        """Returns the first dimension of the ``lm_head`` weight."""
        head_shape = self.lm_head.weight.shape
        return head_shape[0]
|
||
|
||
class Wav2Vec2EncoderSpec(model_spec.LayerSpec):
    """Specification of the Wav2Vec2 encoder: a layer norm and a stack of
    Transformer encoder layers."""

    def __init__(self, num_layers, num_heads):
        """Builds the encoder spec.

        Arguments:
          num_layers: Number of Transformer encoder layer specs to create.
          num_heads: Number of heads, stored as a NumPy int16 scalar.
        """
        self.num_heads = np.int16(num_heads)
        # Covers the wav2vec2.encoder modules except pos_conv_embed, which is
        # left out because groups=16 is not supported.
        self.layer_norm = common_spec.LayerNormSpec()
        encoder_layers = []
        for _ in range(num_layers):
            encoder_layers.append(transformer_spec.TransformerEncoderLayerSpec())
        self.layer = encoder_layers
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,3 +5,4 @@ OpenNMT-tf==2.30.* | |
tensorflow-cpu==2.11.* | ||
pytest | ||
wurlitzer==3.0.*;platform_system=='Linux' | ||
torch |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't this one take raw audio, not a mel spectrogram?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's correct. It should be raw audio, not a Mel spectrogram. How can we fix it? By making another PR for this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not a project maintainer/member/contributor, but I would guess so.