Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Transformers in the Wav2Vec2 Encoder for the ASR Inference #1520

Merged
merged 20 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ set(SOURCES
src/layers/common.cc
src/layers/decoder.cc
src/layers/transformer.cc
src/layers/wav2vec2.cc
src/layers/whisper.cc
src/logging.cc
src/models/language_model.cc
Expand All @@ -124,6 +125,7 @@ set(SOURCES
src/models/model_reader.cc
src/models/sequence_to_sequence.cc
src/models/transformer.cc
src/models/wav2vec2.cc
src/models/whisper.cc
src/ops/activation.cc
src/ops/add.cc
Expand Down
47 changes: 47 additions & 0 deletions include/ctranslate2/layers/wav2vec2.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#pragma once

#include "ctranslate2/layers/transformer.h"

namespace ctranslate2 {
namespace layers {

// Wav2Vec2 speech encoder: a stack of Transformer encoder layers followed by a
// final layer norm. Only wav2vec2.encoder is implemented; the convolutional
// feature extractor and pos_conv_embed are excluded (see note below), so the
// caller must provide already-extracted features.
class Wav2Vec2Encoder : public Layer {
public:
// Loads the encoder weights from the model under the given variable scope.
Wav2Vec2Encoder(const models::Model& model, const std::string& scope);

// Runs the encoder on the input features and writes the result to output.
void operator()(const StorageView& features, StorageView& output);

DataType output_type() const override {
return _output_norm.output_type();
}

dim_t output_size() const override {
return _output_norm.output_size();
}

// Expected feature dimension of the input.
// NOTE(review): hard-coded to the wav2vec2-large hidden size — confirm this
// should not be read from the model variables for other checkpoint sizes.
dim_t input_size() const {
return 1024;
}

// Returns true if the features already look like an encoder output rather
// than encoder input, so encoding can be skipped.
bool is_encoded(const StorageView& features) const {
// Input features shape: [batch_size, input_size, input_time]
// Encoder output shape: [batch_size, input_time // 2, output_size]
//
// input_time is variable so we check that dimension 1 is different than its original value.

return (features.rank() == 3
&& features.dim(2) == output_size()
&& features.dim(1) != input_size());
}

private:
const ops::GELU _gelu;
// wav2vec2.encoder modules except pos_conv_embed due to groups=16 being not supported
//const ops::Transpose _transpose;
const dim_t _num_heads;
const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
const LayerNorm _output_norm;
};

}
}
72 changes: 72 additions & 0 deletions include/ctranslate2/models/wav2vec2.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#pragma once

//#include "ctranslate2/generation.h"
#include "ctranslate2/layers/wav2vec2.h"
#include "ctranslate2/models/model.h"
#include "ctranslate2/replica_pool.h"

namespace ctranslate2 {
namespace models {

// Options for Wav2Vec2 inference.
// NOTE(review): these fields mirror WhisperOptions, but the Wav2Vec2 replica
// only exposes encode() in this view — confirm which options (if any) are
// actually consumed before relying on them.
struct Wav2Vec2Options {
// Maximum generation length.
size_t max_length = 448;

// Randomly sample from the top K candidates (set 0 to sample from the full distribution).
size_t sampling_topk = 1;

// Maximum index of the first predicted timestamp.
size_t max_initial_timestamp_index = 50;

// Suppress blank outputs at the beginning of the sampling.
bool suppress_blank = true;

// List of token IDs to suppress.
// -1 will suppress a default set of symbols as defined in the model config.json file.
std::vector<int> suppress_tokens = {-1};
};


// Model container for a converted Wav2Vec2 checkpoint: holds the weights and
// the vocabulary used by the CTC output layer.
class Wav2Vec2Model : public Model {
public:
// Returns the vocabulary loaded from the converted model.
const Vocabulary& get_vocabulary() const;
size_t current_spec_revision() const override;
// Whether the named variable may be quantized / is a linear weight
// (used by the weight loading and quantization machinery).
bool is_quantizable(const std::string& variable_name) const override;
bool is_linear_weight(const std::string& variable_name) const override;
std::unique_ptr<Model> clone() const override;

bool use_global_int16_scale() const override {
return false;
}

protected:
// Loads the vocabulary (and any other resources) from the model files.
void initialize(ModelReader& model_reader) override;
private:
std::shared_ptr<const Vocabulary> _vocabulary;
};

// A per-worker instance of the Wav2Vec2 model holding its own encoder layers.
class Wav2Vec2Replica : public ModelReplica {
public:
static std::unique_ptr<Wav2Vec2Replica> create_from_model(const Model& model);

// NOTE(review): single-argument constructor — consider marking it explicit.
Wav2Vec2Replica(const std::shared_ptr<const Wav2Vec2Model>& model);

// Encodes the features and optionally copies the result back to the CPU.
StorageView encode(StorageView features, const bool to_cpu);

private:
const std::shared_ptr<const Wav2Vec2Model> _model;
const std::unique_ptr<layers::Wav2Vec2Encoder> _encoder;

// Runs the encoder unless the features already look encoded
// (see Wav2Vec2Encoder::is_encoded).
StorageView maybe_encode(StorageView features);
};

// Thread-safe pool of Wav2Vec2 replicas scheduling encode requests
// across workers.
class Wav2Vec2 : public ReplicaPool<Wav2Vec2Replica> {
public:
using ReplicaPool::ReplicaPool;

// Asynchronously encodes the features on one of the pooled replicas.
std::future<StorageView> encode(const StorageView& features, const bool to_cpu);

};

}
}
1 change: 1 addition & 0 deletions python/cpp/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,5 @@ PYBIND11_MODULE(_ext, m)
ctranslate2::python::register_generator(m);
ctranslate2::python::register_encoder(m);
ctranslate2::python::register_whisper(m);
ctranslate2::python::register_wav2vec2(m);
}
1 change: 1 addition & 0 deletions python/cpp/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace ctranslate2 {
void register_translation_stats(py::module& m);
void register_translator(py::module& m);
void register_whisper(py::module& m);
void register_wav2vec2(py::module& m);

}
}
93 changes: 93 additions & 0 deletions python/cpp/wav2vec2.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#include "module.h"

#include <ctranslate2/models/wav2vec2.h>

#include "replica_pool.h"

namespace ctranslate2 {
namespace python {

// Python-facing wrapper exposing a blocking encode() over the asynchronous
// models::Wav2Vec2 replica pool.
class Wav2Vec2Wrapper : public ReplicaPoolHelper<models::Wav2Vec2> {
public:
using ReplicaPoolHelper::ReplicaPoolHelper;

// Submits the features to the pool and blocks until the encoder output is ready.
StorageView encode(const StorageView& features, const bool to_cpu) {
return _pool->encode(features, to_cpu).get();
}
};


// Registers the ctranslate2.models.Wav2Vec2 Python class and its bindings.
void register_wav2vec2(py::module& m) {
  py::class_<Wav2Vec2Wrapper>(
    m, "Wav2Vec2",
    R"pbdoc(
        Implements the Wav2Vec2 speech recognition model published by Facebook.

        See Also:
           https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec
    )pbdoc")

    .def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, py::object>(),
         py::arg("model_path"),
         py::arg("device")="cpu",
         py::kw_only(),
         py::arg("device_index")=0,
         py::arg("compute_type")="default",
         py::arg("inter_threads")=1,
         py::arg("intra_threads")=0,
         py::arg("max_queued_batches")=0,
         py::arg("files")=py::none(),
         R"pbdoc(
             Initializes a Wav2Vec2 model from a converted model.

             Arguments:
               model_path: Path to the CTranslate2 model directory.
               device: Device to use (possible values are: cpu, cuda, auto).
               device_index: Device IDs where to place this model on.
               compute_type: Model computation type or a dictionary mapping a device name
                 to the computation type (possible values are: default, auto, int8, int8_float32,
                 int8_float16, int8_bfloat16, int16, float16, bfloat16, float32).
               inter_threads: Number of workers to allow executing multiple batches in parallel.
               intra_threads: Number of OpenMP threads per worker (0 to use a default value).
               max_queued_batches: Maximum numbers of batches in the worker queue (-1 for unlimited,
                 0 for an automatic value). When the queue is full, future requests will block
                 until a free slot is available.
               files: Load model files from the memory. This argument is a dictionary mapping
                 file names to file contents as file-like or bytes objects. If this is set,
                 :obj:`model_path` acts as an identifier for this model.
         )pbdoc")

    .def_property_readonly("device", &Wav2Vec2Wrapper::device,
                           "Device this model is running on.")
    .def_property_readonly("device_index", &Wav2Vec2Wrapper::device_index,
                           "List of device IDs where this model is running on.")
    .def_property_readonly("compute_type", &Wav2Vec2Wrapper::compute_type,
                           "Computation type used by the model.")
    .def_property_readonly("num_workers", &Wav2Vec2Wrapper::num_replicas,
                           "Number of model workers backing this instance.")
    .def_property_readonly("num_queued_batches", &Wav2Vec2Wrapper::num_queued_batches,
                           "Number of batches waiting to be processed.")
    .def_property_readonly("num_active_batches", &Wav2Vec2Wrapper::num_active_batches,
                           "Number of batches waiting to be processed or currently processed.")

    // GIL is released so other Python threads can run while encoding.
    .def("encode", &Wav2Vec2Wrapper::encode,
         py::arg("features"),
         py::arg("to_cpu")=false,
         py::call_guard<py::gil_scoped_release>(),
         R"pbdoc(
             Encodes the input features.

             Arguments:
               features: Input features for the encoder, as a float array with shape
                 ``[batch_size, input_size, input_time]``. Unlike Whisper, this model
                 does not take a Mel spectrogram: the input is derived from the raw
                 audio waveform.
               to_cpu: Copy the encoder output to the CPU before returning the value.

             Returns:
               The encoder output.
         )pbdoc")

    ;
}

}
}
44 changes: 44 additions & 0 deletions python/ctranslate2/converters/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
common_spec,
model_spec,
transformer_spec,
wav2vec2_spec,
whisper_spec,
)

Expand Down Expand Up @@ -935,6 +936,49 @@ def set_conv1d(self, spec, module):
spec.bias = module.bias


@register_loader("Wav2Vec2Config")
class Wav2Vec2Loader(BartLoader):
    """Converts Wav2Vec2ForCTC models from Transformers.

    Only the Transformer encoder layers and the final layer norm are
    converted: the convolutional feature extractor and pos_conv_embed
    (a Conv1D with groups=16, which is not supported here) are excluded,
    so they must be applied before calling the converted encoder.
    """

    @property
    def architecture_name(self):
        return "Wav2Vec2ForCTC"

    def get_model_spec(self, model):
        # Wav2Vec2 encoder Wav2Vec2PositionalConvEmbedding conv1d has groups 16
        # that doesn't look available here so we make Wav2Vec2 encoder layers only
        spec = wav2vec2_spec.Wav2Vec2Spec(
            model.wav2vec2.encoder.config.num_hidden_layers,
            model.wav2vec2.encoder.config.num_attention_heads,
        )

        # Alias the Wav2Vec2 module names to the BART names expected by
        # BartLoader.set_encoder() so the weights map without duplication.
        for layer in model.wav2vec2.encoder.layers:
            layer.self_attn = layer.attention
            layer.self_attn_layer_norm = layer.layer_norm
            layer.activation_fn = layer.feed_forward.intermediate_act_fn
            layer.fc1 = layer.feed_forward.intermediate_dense
            layer.fc2 = layer.feed_forward.output_dense

        self.set_encoder(spec.encoder, model.wav2vec2.encoder)
        # lm_head is converted only for Wav2Vec2Spec.get_vocabulary_size()
        self.set_linear(spec.lm_head, model.lm_head)
        return spec

    def set_config(self, config, model, tokenizer):
        # The Wav2Vec2 spec has no configuration to save.
        return

    def get_vocabulary(self, model, tokenizer):
        return tokenizer.get_vocab()

    def set_vocabulary(self, spec, tokens):
        spec.register_vocabulary(tokens)

    def set_common_layers(self, spec, module):
        # Only the final layer norm is shared (no learned positional embeddings).
        self.set_layer_norm(spec.layer_norm, module.layer_norm)

@register_loader("T5Config")
class T5Loader(ModelLoader):
@property
Expand Down
1 change: 1 addition & 0 deletions python/ctranslate2/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

try:
from ctranslate2._ext import (
Wav2Vec2,
Whisper,
WhisperGenerationResult,
WhisperGenerationResultAsync,
Expand Down
1 change: 1 addition & 0 deletions python/ctranslate2/specs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
TransformerEncoderSpec,
TransformerSpec,
)
from ctranslate2.specs.wav2vec2_spec import Wav2Vec2Spec
from ctranslate2.specs.whisper_spec import WhisperSpec
43 changes: 43 additions & 0 deletions python/ctranslate2/specs/wav2vec2_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import List, Optional, Tuple

import numpy as np

from ctranslate2.specs import common_spec, model_spec, transformer_spec


class Wav2Vec2Config(model_spec.ModelConfig):
    """Configuration for the Wav2Vec2 model."""

    def __init__(self):
        # Run the base ModelConfig initialization instead of skipping it:
        # the original bare `return` left base attributes unset.
        super().__init__()


class Wav2Vec2Spec(model_spec.LanguageModelSpec):
    """Model specification for Wav2Vec2: a Transformer encoder with a CTC head."""

    def __init__(self, num_layers, num_heads):
        super().__init__()
        self.encoder = Wav2Vec2EncoderSpec(num_layers, num_heads)
        self.lm_head = common_spec.LinearSpec()

    @property
    def name(self):
        return "Wav2Vec2Spec"

    @property
    def revision(self):
        return 3

    def get_default_config(self):
        return Wav2Vec2Config()

    def get_vocabulary_size(self):
        # The CTC head projects to the vocabulary, so its output
        # dimension is the vocabulary size.
        weight_shape = self.lm_head.weight.shape
        return weight_shape[0]


class Wav2Vec2EncoderSpec(model_spec.LayerSpec):
    """Specification of the Wav2Vec2 Transformer encoder.

    Contains the wav2vec2.encoder modules except pos_conv_embed,
    whose groups=16 convolution is not supported.
    """

    def __init__(self, num_layers, num_heads):
        self.num_heads = np.dtype("int16").type(num_heads)
        self.layer_norm = common_spec.LayerNormSpec()
        self.layer = []
        for _ in range(num_layers):
            self.layer.append(transformer_spec.TransformerEncoderLayerSpec())
1 change: 1 addition & 0 deletions python/tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ OpenNMT-tf==2.30.*
tensorflow-cpu==2.11.*
pytest
wurlitzer==3.0.*;platform_system=='Linux'
torch
Loading