From 57958431764a055e8d2112fa3dce258f262de9cc Mon Sep 17 00:00:00 2001 From: homink Date: Tue, 5 Nov 2024 00:50:15 -0800 Subject: [PATCH 1/4] Backward compatibility for the Wav2Vec2 ASR model (#1810) * update description for the wav2vec2 model * the backward compatiability support for the wav2vec2 ASR model * dummy * dummy push * dummy push * dummy push * header update * dummpy push --------- Co-authored-by: hkwon --- include/ctranslate2/layers/wav2vec2.h | 14 +++-- python/cpp/wav2vec2.cc | 5 +- src/layers/wav2vec2.cc | 80 ++++++++++++++++----------- 3 files changed, 58 insertions(+), 41 deletions(-) diff --git a/include/ctranslate2/layers/wav2vec2.h b/include/ctranslate2/layers/wav2vec2.h index 29dea9783..6a7b5ca9d 100644 --- a/include/ctranslate2/layers/wav2vec2.h +++ b/include/ctranslate2/layers/wav2vec2.h @@ -1,5 +1,6 @@ #pragma once +#include #include "ctranslate2/layers/transformer.h" namespace ctranslate2 { @@ -81,17 +82,18 @@ namespace ctranslate2 { } private: - const Wav2Vec2LayerNormConvLayer _feat_layer0; - const std::vector> _feat_layers; - const LayerNorm _fp_norm; - const Dense _fp_ff; - const Wav2Vec2PosConvLayer _pos_conv_embed; + const StorageView* _upgraded_model; + std::optional _feat_layer0; + std::optional>> _feat_layers; + std::optional _fp_norm; + std::optional _fp_ff; + std::optional _pos_conv_embed; const ops::Transpose _transpose; const ops::GELU _gelu; const dim_t _num_heads; const std::vector> _layers; const LayerNorm _output_norm; - const Dense _lm_head; + std::optional _lm_head; }; } diff --git a/python/cpp/wav2vec2.cc b/python/cpp/wav2vec2.cc index f6791a451..b0ade0d9f 100644 --- a/python/cpp/wav2vec2.cc +++ b/python/cpp/wav2vec2.cc @@ -86,8 +86,9 @@ namespace ctranslate2 { Encodes the input features. Arguments: - features: Mel spectogram of the audio, as a float array with shape - ``[batch_size, 80, 3000]``. + features: hidden_states (up to v.4.3.1, https://github.com/OpenNMT/CTranslate2/blob/59c7dda738892df7a064aa360d0e45a4c3840b07/python/tests/test_transformers.py#L1028) or + raw audio, as a float array with shape (followed by VAD) + ``[batch_size, 409, 1024]`` or ``[batch_size, 1, 131200]`` to_cpu: Copy the encoder output to the CPU before returning the value. Returns: diff --git a/src/layers/wav2vec2.cc b/src/layers/wav2vec2.cc index defbf0d84..a8f64649d 100644 --- a/src/layers/wav2vec2.cc +++ b/src/layers/wav2vec2.cc @@ -46,14 +46,7 @@ namespace ctranslate2 { } Wav2Vec2Encoder::Wav2Vec2Encoder(const models::Model& model, const std::string& scope) - : _feat_layer0(model, scope + "/feat_layer0", /*stride=*/5, /*padding=*/0) - , _feat_layers(build_layers_list(model, - scope + "/feat_layer", - /*stride=*/2, - /*padding=*/0)) - , _fp_norm(model, scope + "/fp_layer_norm") - , _fp_ff(model, scope + "/fp_projection", nullptr, true) - , _pos_conv_embed(model, scope + "/pos_conv_embed") + : _upgraded_model(model.get_variable_if_exists(scope + "/lm_head/weight")) , _num_heads(model.get_attribute_with_default(scope + "/num_heads", 8)) , _transpose({0, 2, 1}) , _layers(build_layers_list(model, @@ -62,8 +55,18 @@ namespace ctranslate2 { /*pre_norm=*/true, ops::ActivationType::GELU)) , _output_norm(model, scope + "/layer_norm") - , _lm_head(model, scope + "/lm_head", nullptr, true) { + if (_upgraded_model) { + _feat_layer0.emplace(model, scope + "/feat_layer0", /*stride=*/5, /*padding=*/0); + _feat_layers.emplace(build_layers_list(model, + scope + "/feat_layer", + /*stride=*/2, + /*padding=*/0)); + _fp_norm.emplace(model, scope + "/fp_layer_norm"); + _fp_ff.emplace(model, scope + "/fp_projection", nullptr, true); + _pos_conv_embed.emplace(model, scope + "/pos_conv_embed"); + _lm_head.emplace(model, scope + "/lm_head", nullptr, true); + } } void Wav2Vec2Encoder::operator()(const StorageView& features, StorageView& output) { @@ -74,33 +77,44 @@ namespace ctranslate2 { throw std::invalid_argument("Expected input features to have 3 dimensions, but got " + std::to_string(features.rank()) + " dimension(s) instead"); - - // Wav2Vec2FeatureExtractor------------------------------------ - StorageView feat_buffer(features.dtype(), features.device()); - StorageView feat_buffer2(features.dtype(), features.device()); - feat_buffer = std::move(features); - _feat_layer0(feat_buffer, output); - feat_buffer = std::move(output); - for (dim_t l = 0; l < _feat_layers.size(); l++) { - (*_feat_layers[l])(feat_buffer, output); - if (l < _feat_layers.size() - 1 ) { - feat_buffer = std::move(output); + if (_upgraded_model) { + // Wav2Vec2FeatureExtractor------------------------------------ + StorageView feat_buffer(features.dtype(), features.device()); + StorageView feat_buffer2(features.dtype(), features.device()); + feat_buffer = std::move(features); + (*_feat_layer0)(feat_buffer, output); //_feat_layer0(feat_buffer, output); + feat_buffer = std::move(output); + for (dim_t l = 0; l < _feat_layers->size(); l++) { + (*_feat_layers.value()[l])(feat_buffer, output); + if (l < _feat_layers->size() - 1 ) { + feat_buffer = std::move(output); + } } + _transpose(output, feat_buffer); + // Wav2Vec2FeatureProjection----------------------------------- + (*_fp_norm)(feat_buffer, output); //_fp_norm(feat_buffer, output); + (*_fp_ff)(output, feat_buffer); //_fp_ff(output, feat_buffer); + // Wav2Vec2PositionalConvEmbedding----------------------------- + (*_pos_conv_embed)(feat_buffer, feat_buffer2); //_pos_conv_embed(feat_buffer, feat_buffer2); + // Wav2Vec2EncoderLayerStableLayerNorm------------------------- + for (const auto& layer : _layers) { + (*layer)(feat_buffer2, nullptr, feat_buffer); + feat_buffer2 = std::move(feat_buffer); + } + _output_norm(feat_buffer2, feat_buffer); + + (*_lm_head)(feat_buffer, output); //_lm_head(feat_buffer, output); } - _transpose(output, feat_buffer); - // Wav2Vec2FeatureProjection----------------------------------- - _fp_norm(feat_buffer, output); - _fp_ff(output, feat_buffer); - // Wav2Vec2PositionalConvEmbedding----------------------------- - _pos_conv_embed(feat_buffer, feat_buffer2); - // Wav2Vec2EncoderLayerStableLayerNorm------------------------- - for (const auto& layer : _layers) { - (*layer)(feat_buffer2, nullptr, feat_buffer); - feat_buffer2 = std::move(feat_buffer); - } - _output_norm(feat_buffer2, feat_buffer); + else { // backward compatibility for the previous converted model + StorageView input(output_type(), features.device()); + input = features; + for (const auto& layer : _layers) { + (*layer)(input, nullptr, output); + input = std::move(output); + } - _lm_head(feat_buffer, output); + _output_norm(input, output); + } } } From e14fffb59fa5bb30d3160577340bfd568ee30397 Mon Sep 17 00:00:00 2001 From: Mahmoud Ashraf Date: Tue, 19 Nov 2024 10:58:26 +0200 Subject: [PATCH 2/4] Update deprecated actions and images in CI (#1815) * update action versions * use more multi cpu in testing * fix uploading and downloading artifacts * more fixes to download artifacts * upgrade `macos` to 13 * remove `xdist` * better artifact names --- .github/workflows/ci.yml | 46 +++++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c6f376231..82fcfe6d2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,7 @@ jobs: backend: [mkl, dnnl] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: recursive @@ -82,7 +82,7 @@ jobs: backend: [openblas] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: recursive @@ -137,11 +137,11 @@ jobs: include: - os: ubuntu-20.04 arch: aarch64 - - os: macos-12 + - os: macos-13 arch: arm64 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: recursive @@ -150,7 +150,7 @@ jobs: name: Set up QEMU - name: Build wheels - uses: pypa/cibuildwheel@v2.16.5 + uses: pypa/cibuildwheel@v2.21.3 with: package-dir: python output-dir: python/wheelhouse @@ -168,9 +168,9 @@ jobs: CIBW_SKIP: pp* *-musllinux_* - name: Upload Python wheels - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: - name: python-wheels + name: python-wheels-${{ runner.os }}-${{ matrix.arch }} path: python/wheelhouse @@ -185,11 +185,11 @@ jobs: steps: - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.8 - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Prepare test environment shell: bash @@ -197,9 +197,11 @@ jobs: ./python/tools/prepare_test_environment.sh - name: Download Python wheels - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: - name: python-wheels + pattern: python-wheels-${{ runner.os }}-* + merge-multiple: true + path: . - name: Install wheel if: startsWith(matrix.os, 'ubuntu') @@ -222,10 +224,10 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.8 @@ -257,9 +259,11 @@ jobs: steps: - name: Download Python wheels - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: - name: python-wheels + pattern: python-wheels-* + merge-multiple: true + path: . - name: Publish Python wheels to PyPI uses: pypa/gh-action-pypi-publish@release/v1 @@ -272,7 +276,7 @@ jobs: build-and-push-docker-images: runs-on: ubuntu-20.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: recursive @@ -299,17 +303,19 @@ jobs: needs: [check-python-style, build-python-wheels] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.8 - name: Download CTranslate2 wheels - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: - name: python-wheels + pattern: python-wheels-${{ runner.os }}-* + merge-multiple: true + path: . - name: Install CTranslate2 wheel run: | From 84609939e953bd21c20a28c440d4dbae503d47b0 Mon Sep 17 00:00:00 2001 From: Mahmoud Ashraf Date: Mon, 25 Nov 2024 11:12:12 +0200 Subject: [PATCH 3/4] prevent double library def (#1818) --- CMakeLists.txt | 62 +++++++++++++++++++++++++------------------------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 62fc33640..983dd7041 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -551,39 +551,9 @@ if (WITH_CUDA) else() list(APPEND LIBRARIES ${CUDA_CUBLAS_LIBRARIES}) endif() - set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE) - cuda_add_library(${PROJECT_NAME} - ${SOURCES} - src/cuda/allocator.cc - src/cuda/primitives.cu - src/cuda/random.cu - src/cuda/utils.cc - src/ops/alibi_add_gpu.cu - src/ops/bias_add_gpu.cu - src/ops/concat_split_slide_gpu.cu - src/ops/conv1d_gpu.cu - src/ops/dequantize_gpu.cu - src/ops/flash_attention_gpu.cu - src/ops/gather_gpu.cu - src/ops/gumbel_max_gpu.cu - src/ops/layer_norm_gpu.cu - src/ops/mean_gpu.cu - src/ops/multinomial_gpu.cu - src/ops/rms_norm_gpu.cu - src/ops/rotary_gpu.cu - src/ops/softmax_gpu.cu - src/ops/tile_gpu.cu - src/ops/topk_gpu.cu - src/ops/topp_mask_gpu.cu - src/ops/quantize_gpu.cu - src/ops/nccl_ops_gpu.cu - src/ops/awq/gemm_gpu.cu - src/ops/awq/gemv_gpu.cu - src/ops/awq/dequantize_gpu.cu - ) if (WITH_FLASH_ATTN) add_definitions(-DCT2_WITH_FLASH_ATTN) - cuda_add_library(${PROJECT_NAME} + list(APPEND SOURCES src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu @@ -653,6 +623,36 @@ if (WITH_CUDA) src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu PROPERTIES COMPILE_FLAGS "--use_fast_math") endif() + set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE) + cuda_add_library(${PROJECT_NAME} + ${SOURCES} + src/cuda/allocator.cc + src/cuda/primitives.cu + src/cuda/random.cu + src/cuda/utils.cc + src/ops/alibi_add_gpu.cu + src/ops/bias_add_gpu.cu + src/ops/concat_split_slide_gpu.cu + src/ops/conv1d_gpu.cu + src/ops/dequantize_gpu.cu + src/ops/flash_attention_gpu.cu + src/ops/gather_gpu.cu + src/ops/gumbel_max_gpu.cu + src/ops/layer_norm_gpu.cu + src/ops/mean_gpu.cu + src/ops/multinomial_gpu.cu + src/ops/rms_norm_gpu.cu + src/ops/rotary_gpu.cu + src/ops/softmax_gpu.cu + src/ops/tile_gpu.cu + src/ops/topk_gpu.cu + src/ops/topp_mask_gpu.cu + src/ops/quantize_gpu.cu + src/ops/nccl_ops_gpu.cu + src/ops/awq/gemm_gpu.cu + src/ops/awq/gemv_gpu.cu + src/ops/awq/dequantize_gpu.cu + ) elseif(WITH_CUDNN) From 2870fe3ddce49c85ecab4f84fc5e4b01b3a740fe Mon Sep 17 00:00:00 2001 From: Minh-Thuc <46375464+minhthuc2502@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:37:42 +0100 Subject: [PATCH 4/4] Support qwen2 (#1820) * support qwen2 * fix flake --- README.md | 2 +- python/ctranslate2/converters/transformers.py | 108 ++++++++++++++++++ 2 files changed, 109 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index bfb64c851..fb91f5eb3 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ The project implements a custom runtime that applies many performance optimizati The following model types are currently supported: * Encoder-decoder models: Transformer base/big, M2M-100, NLLB, BART, mBART, Pegasus, T5, Whisper -* Decoder-only models: GPT-2, GPT-J, GPT-NeoX, OPT, BLOOM, MPT, Llama, Mistral, Gemma, CodeGen, GPTBigCode, Falcon +* Decoder-only models: GPT-2, GPT-J, GPT-NeoX, OPT, BLOOM, MPT, Llama, Mistral, Gemma, CodeGen, GPTBigCode, Falcon, Qwen2 * Encoder-only models: BERT, DistilBERT, XLM-RoBERTa Compatible models should be first converted into an optimized model format. The library includes converters for multiple frameworks: diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py index 7655662dd..d5f935f95 100644 --- a/python/ctranslate2/converters/transformers.py +++ b/python/ctranslate2/converters/transformers.py @@ -1956,6 +1956,114 @@ def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2): gc.collect() +@register_loader("Qwen2Config") +class Qwen2Loader(ModelLoader): + @property + def architecture_name(self): + return "Qwen2ForCausalLM" + + def get_model_spec(self, model): + num_layers = model.config.num_hidden_layers + + num_heads = model.config.num_attention_heads + num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads) + if num_heads_kv == num_heads: + num_heads_kv = None + + rope_scaling = getattr(model.config, "rope_scaling", None) + if rope_scaling: + rope_type = rope_scaling.get("type") or rope_scaling["rope_type"] + rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_type) + rotary_scaling_factor = rope_scaling["factor"] + + if rotary_scaling_type is None: + raise NotImplementedError( + "RoPE scaling type '%s' is not yet implemented. " + "The following RoPE scaling types are currently supported: %s" + % (rope_scaling["type"], ", ".join(_SUPPORTED_ROPE_SCALING.keys())) + ) + else: + rotary_scaling_type = None + rotary_scaling_factor = 1 + + spec = transformer_spec.TransformerDecoderModelSpec.from_config( + num_layers, + num_heads, + activation=common_spec.Activation.SWISH, + pre_norm=True, + ffn_glu=True, + rms_norm=True, + rotary_dim=0, + rotary_interleave=False, + rotary_scaling_type=rotary_scaling_type, + rotary_scaling_factor=rotary_scaling_factor, + rotary_base=getattr(model.config, "rope_theta", 10000), + num_heads_kv=num_heads_kv, + ) + + self.set_decoder(spec.decoder, model.model) + self.set_linear(spec.decoder.projection, model.lm_head) + return spec + + def get_vocabulary(self, model, tokenizer): + tokens = super().get_vocabulary(model, tokenizer) + + extra_ids = model.config.vocab_size - len(tokens) + for i in range(extra_ids): + tokens.append("" % i) + return tokens + + def set_vocabulary(self, spec, tokens): + spec.register_vocabulary(tokens) + + def set_config(self, config, model, tokenizer): + config.bos_token = ( + tokenizer.bos_token + if tokenizer.bos_token is not None + else tokenizer.pad_token + ) + config.eos_token = tokenizer.eos_token + config.unk_token = ( + tokenizer.unk_token if tokenizer.unk_token is not None else "" + ) + config.layer_norm_epsilon = model.config.rms_norm_eps + + def set_layer_norm(self, spec, layer_norm): + spec.gamma = layer_norm.weight + + def set_decoder(self, spec, module): + spec.scale_embeddings = False + self.set_embeddings(spec.embeddings, module.embed_tokens) + self.set_layer_norm(spec.layer_norm, module.norm) + + for layer_spec, layer in zip(spec.layer, module.layers): + self.set_layer_norm( + layer_spec.self_attention.layer_norm, layer.input_layernorm + ) + self.set_layer_norm( + layer_spec.ffn.layer_norm, layer.post_attention_layernorm + ) + + split_layers = [common_spec.LinearSpec() for _ in range(3)] + self.set_linear(split_layers[0], layer.self_attn.q_proj) + self.set_linear(split_layers[1], layer.self_attn.k_proj) + self.set_linear(split_layers[2], layer.self_attn.v_proj) + + utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers) + self.set_linear( + layer_spec.self_attention.linear[1], + layer.self_attn.o_proj, + ) + + self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj) + self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj) + self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj) + + delattr(layer, "self_attn") + delattr(layer, "mlp") + gc.collect() + + @register_loader("MixFormerSequentialConfig") class MixFormerSequentialLoader(ModelLoader): @property