From a25517457c82e8707574e9eb6e6041d78f798fe4 Mon Sep 17 00:00:00 2001
From: Chii Yeh
Date: Fri, 3 Nov 2023 13:36:49 +0800
Subject: [PATCH] Support conversion for distil-whisper

The previous code assumed the same number of encoder and decoder layers.
This change removes that assumption and obtains the two layer counts
separately.
---
 python/ctranslate2/converters/transformers.py |  2 ++
 python/ctranslate2/specs/whisper_spec.py      | 20 +++++++++++++++-----
 2 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py
index 2c8138da0..38aeab296 100644
--- a/python/ctranslate2/converters/transformers.py
+++ b/python/ctranslate2/converters/transformers.py
@@ -878,6 +878,8 @@ def get_model_spec(self, model):
         spec = whisper_spec.WhisperSpec(
             model.config.encoder_layers,
             model.config.encoder_attention_heads,
+            model.config.decoder_layers,
+            model.config.decoder_attention_heads,
         )
 
         self.set_encoder(spec.encoder, model.model.encoder)
diff --git a/python/ctranslate2/specs/whisper_spec.py b/python/ctranslate2/specs/whisper_spec.py
index fffd5c0f8..e32453e1c 100644
--- a/python/ctranslate2/specs/whisper_spec.py
+++ b/python/ctranslate2/specs/whisper_spec.py
@@ -26,17 +26,27 @@ def __init__(
 class WhisperSpec(model_spec.LanguageModelSpec):
     """Describes a Whisper model."""
 
-    def __init__(self, num_layers, num_heads):
+    def __init__(
+        self,
+        num_encoder_layers,
+        num_encoder_heads,
+        num_decoder_layers,
+        num_decoder_heads,
+    ):
         """Initializes the model specification.
 
         Args:
-          num_layers: The number of encoder and decoder layers.
-          num_heads: The number of attention heads.
+          num_encoder_layers: The number of encoder layers.
+          num_encoder_heads: The number of encoder attention heads.
+          num_decoder_layers: The number of decoder layers.
+          num_decoder_heads: The number of decoder attention heads.
         """
         super().__init__()
-        self.encoder = WhisperEncoderSpec(num_layers, num_heads)
+        self.encoder = WhisperEncoderSpec(num_encoder_layers, num_encoder_heads)
         self.decoder = transformer_spec.TransformerDecoderSpec(
-            num_layers, num_heads, activation=common_spec.Activation.GELU
+            num_decoder_layers,
+            num_decoder_heads,
+            activation=common_spec.Activation.GELU,
         )
         self.decoder.scale_embeddings = False
 
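
Not part of the patch: a minimal sketch of how a distil-whisper checkpoint
could be converted once this change is in. The model id, output directory,
and quantization choice below are illustrative assumptions; the converter
class is the existing CTranslate2 Transformers converter.

    from ctranslate2.converters import TransformersConverter

    # distil-whisper checkpoints have fewer decoder layers than encoder
    # layers; with this patch the spec reads both counts from the Hugging
    # Face config instead of reusing the encoder's values for the decoder.
    converter = TransformersConverter("distil-whisper/distil-large-v2")
    converter.convert("distil-large-v2-ct2", quantization="float16")

Previously the decoder spec was sized with the encoder's layer and head
counts, so a checkpoint whose decoder is shallower than its encoder could
not be represented correctly.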