Eoleconv (#1832)
* Add eole to ct2 converter
vince62s authored Dec 18, 2024
1 parent 2870fe3 commit 8dee2a2
Showing 1 changed file with 352 additions and 0 deletions.
352 changes: 352 additions & 0 deletions python/ctranslate2/converters/eole_ct2.py
@@ -0,0 +1,352 @@
import argparse

from eole.config.run import PredictConfig
from eole.constants import PositionEncodingType
from eole.inputters.inputter import vocabs_to_dict
from eole.models.model import BaseModel

from ctranslate2.converters import utils
from ctranslate2.converters.converter import Converter
from ctranslate2.specs import common_spec, transformer_spec

_SUPPORTED_ACTIVATIONS = {
"gelu": common_spec.Activation.GELU,
"fast_gelu": common_spec.Activation.GELUTanh,
"relu": common_spec.Activation.RELU,
"gated-silu": common_spec.Activation.SWISH,
}
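# Note: "gated-silu" maps to the SWISH activation here; the spec builders below
# additionally enable the gated feed-forward path (ffn_glu) for this activation.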


def _get_model_spec_seq2seq(
config, variables, src_vocabs, tgt_vocabs, num_source_embeddings
):
"""Creates a model specification from the model config."""
with_relative_position = (
getattr(config.embeddings, "position_encoding_type", None)
== PositionEncodingType.Relative
)
with_rotary = (
getattr(config.embeddings, "position_encoding_type", None)
== PositionEncodingType.Rotary
)
if with_rotary:
raise ValueError(
"Rotary embeddings are not supported yet for encoder/decoder models"
)
with_alibi = (
getattr(config.embeddings, "position_encoding_type", None)
== PositionEncodingType.Alibi
)
if with_alibi:
raise ValueError("Alibi is not supported yet for encoder/decoder models")
activation_fn = getattr(config, "mlp_activation_fn", "relu")

# Return the first head of the last layer unless the model was trained with alignments.
if getattr(config.decoder, "lambda_align", 0) == 0:
alignment_layer = -1
alignment_heads = 1
else:
alignment_layer = config.decoder.alignment_layer
alignment_heads = config.decoder.alignment_heads

num_heads = getattr(config.decoder, "heads", 8)
# num_kv = getattr(config.decoder, "heads_kv", 0)
# if num_kv == num_heads or num_kv == 0:
# num_kv = None
# rotary_dim = 0 if with_rotary else None
# rotary_interleave = getattr(config.rope_config, "rotary_interleave", True)
ffn_glu = activation_fn == "gated-silu"
sliding_window = getattr(config, "sliding_window", 0)
if sliding_window != 0:
raise ValueError(
"Sliding window is not suported yet for encoder/decoder models"
)

model_spec = transformer_spec.TransformerSpec.from_config(
(config.encoder.layers, config.decoder.layers),
num_heads,
with_relative_position=with_relative_position,
# alibi=with_alibi,
activation=_SUPPORTED_ACTIVATIONS[activation_fn],
ffn_glu=ffn_glu,
rms_norm=config.layer_norm == "rms",
# rotary_dim=rotary_dim,
# rotary_interleave=rotary_interleave,
# num_heads_kv=num_kv,
# sliding_window=sliding_window,
alignment_layer=alignment_layer,
alignment_heads=alignment_heads,
num_source_embeddings=num_source_embeddings,
# multi_query_attention=getattr(opt, "multiquery", False),
)

set_transformer_spec(model_spec, variables)
for src_vocab in src_vocabs:
model_spec.register_source_vocabulary(src_vocab)
for tgt_vocab in tgt_vocabs:
model_spec.register_target_vocabulary(tgt_vocab)

return model_spec


def _get_model_spec_lm(
config, variables, src_vocabs, tgt_vocabs, num_source_embeddings
):
"""Creates a model specification from the model config."""
with_relative_position = (
getattr(config.embeddings, "position_encoding_type", None)
== PositionEncodingType.Relative
)
with_rotary = (
getattr(config.embeddings, "position_encoding_type", None)
== PositionEncodingType.Rotary
)
with_alibi = (
getattr(config.embeddings, "position_encoding_type", None)
== PositionEncodingType.Alibi
)
activation_fn = getattr(config, "mlp_activation_fn", "relu")
num_heads = getattr(config.decoder, "heads", 8)
num_kv = getattr(config.decoder, "heads_kv", 0)
if num_kv == num_heads or num_kv == 0:
num_kv = None
rotary_dim = 0 if with_rotary else None
rotary_interleave = getattr(config.rope_config, "rotary_interleave", True)
ffn_glu = activation_fn == "gated-silu"
sliding_window = getattr(config, "sliding_window", 0)

model_spec = transformer_spec.TransformerDecoderModelSpec.from_config(
config.decoder.layers,
num_heads,
activation=_SUPPORTED_ACTIVATIONS[activation_fn],
ffn_glu=ffn_glu,
with_relative_position=with_relative_position,
alibi=with_alibi,
rms_norm=config.layer_norm == "rms",
rotary_dim=rotary_dim,
rotary_interleave=rotary_interleave,
num_heads_kv=num_kv,
sliding_window=sliding_window,
# multi_query_attention=getattr(opt, "multiquery", False),
)

set_transformer_decoder(
model_spec.decoder,
variables,
with_encoder_attention=False,
)

for tgt_vocab in tgt_vocabs:
model_spec.register_vocabulary(tgt_vocab)

return model_spec


def get_vocabs(vocab):
src_vocabs = [vocab["src"]]
tgt_vocabs = [vocab["tgt"]]
return src_vocabs, tgt_vocabs


class EoleConverter(Converter):
"""Converts models generated by OpenNMT-py."""

def __init__(self, model_path: str):
"""Initializes the OpenNMT-py converter.
Arguments:
model_path: Path to the OpenNMT-py PyTorch model (.pt file).
"""
self._model_path = model_path

def _load(self):
import torch

config = PredictConfig(model_path=self._model_path, src="dummy")

vocabs, model, model_config = BaseModel.load_test_model(config)
vocabs_dict = vocabs_to_dict(vocabs)

config.model = model_config
src_vocabs, tgt_vocabs = get_vocabs(vocabs_dict)

if config.model.decoder.decoder_type == "transformer_lm":
spec = _get_model_spec_lm(
config.model,
model.state_dict(),
src_vocabs,
tgt_vocabs,
num_source_embeddings=len(src_vocabs),
)
else:
spec = _get_model_spec_seq2seq(
config.model,
model.state_dict(),
src_vocabs,
tgt_vocabs,
num_source_embeddings=len(src_vocabs),
)
spec.config.decoder_start_token = vocabs["decoder_start_token"]

spec.config.bos_token = vocabs["specials"]["bos_token"]
spec.config.eos_token = vocabs["specials"]["eos_token"]
spec.config.unk_token = vocabs["specials"]["unk_token"]
spec.config.layer_norm_epsilon = getattr(config, "norm_eps", 1e-6)

return spec
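
# A minimal programmatic usage sketch (assuming the standard CTranslate2
# Converter.convert(output_dir, ...) API; paths are placeholders):
#
#     converter = EoleConverter("/path/to/eole_model")
#     converter.convert("ct2_model", quantization="int8")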


def set_transformer_spec(spec, variables):
set_transformer_encoder(spec.encoder, variables)
set_transformer_decoder(spec.decoder, variables)


def set_transformer_encoder(spec, variables):
set_input_layers(spec, variables, "src_emb")
set_layer_norm(spec.layer_norm, variables, "encoder.layer_norm")
for i, layer in enumerate(spec.layer):
set_transformer_encoder_layer(
layer, variables, "encoder.transformer_layers.%d" % i
)


def set_transformer_decoder(spec, variables, with_encoder_attention=True):
set_input_layers(spec, variables, "tgt_emb")
set_layer_norm(spec.layer_norm, variables, "decoder.layer_norm")
for i, layer in enumerate(spec.layer):
set_transformer_decoder_layer(
layer,
variables,
"decoder.transformer_layers.%d" % i,
with_encoder_attention=with_encoder_attention,
)

set_linear(spec.projection, variables, "generator")


def set_input_layers(spec, variables, scope):
if hasattr(spec, "position_encodings"):
set_position_encodings(
spec.position_encodings,
variables,
"%s.pe" % scope,
)
else:
spec.scale_embeddings = False

embeddings_specs = spec.embeddings
    # Encoder embeddings are stored in a list (OpenNMT-py/CT2 legacy support for source features).
if isinstance(embeddings_specs, list):
embeddings_specs = embeddings_specs[0]
set_embeddings(embeddings_specs, variables, "%s.embeddings" % scope)


def set_transformer_encoder_layer(spec, variables, scope):
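    # Eole stores the pre-attention norm as "input_layernorm" and the pre-MLP norm
    # as "post_attention_layernorm"; they are mapped onto the layer_norm slots of
    # the self-attention and FFN specs respectively.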
set_multi_head_attention(
spec.self_attention,
variables,
"%s.self_attn" % scope,
self_attention=True,
)
set_layer_norm(
spec.self_attention.layer_norm, variables, "%s.input_layernorm" % scope
)
set_layer_norm(
spec.ffn.layer_norm, variables, "%s.post_attention_layernorm" % scope
)
set_ffn(spec.ffn, variables, "%s.mlp" % scope)


def set_transformer_decoder_layer(spec, variables, scope, with_encoder_attention=True):
set_multi_head_attention(
spec.self_attention,
variables,
"%s.self_attn" % scope,
self_attention=True,
)
set_layer_norm(
spec.self_attention.layer_norm, variables, "%s.input_layernorm" % scope
)
if with_encoder_attention:
set_multi_head_attention(spec.attention, variables, "%s.context_attn" % scope)
set_layer_norm(
spec.attention.layer_norm, variables, "%s.precontext_layernorm" % scope
)
set_layer_norm(
spec.ffn.layer_norm, variables, "%s.post_attention_layernorm" % scope
)
set_ffn(spec.ffn, variables, "%s.mlp" % scope)


def set_ffn(spec, variables, scope):
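    # Eole MLP weight names: "gate_up_proj" feeds the activated (gate) branch,
    # "up_proj" the non-activated branch of gated FFNs, and "down_proj" the output.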
set_linear(spec.linear_0, variables, "%s.gate_up_proj" % scope)
set_linear(spec.linear_1, variables, "%s.down_proj" % scope)
if hasattr(spec, "linear_0_noact"):
set_linear(spec.linear_0_noact, variables, "%s.up_proj" % scope)


def set_multi_head_attention(spec, variables, scope, self_attention=False):
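    # Self-attention fuses the query/key/value projections into a single linear
    # layer; cross-attention keeps the query projection separate and fuses only
    # the key/value projections computed from the encoder output.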
if self_attention:
split_layers = [common_spec.LinearSpec() for _ in range(3)]
set_linear(split_layers[0], variables, "%s.linear_query" % scope)
set_linear(split_layers[1], variables, "%s.linear_keys" % scope)
set_linear(split_layers[2], variables, "%s.linear_values" % scope)
utils.fuse_linear(spec.linear[0], split_layers)
else:
set_linear(spec.linear[0], variables, "%s.linear_query" % scope)
split_layers = [common_spec.LinearSpec() for _ in range(2)]
set_linear(split_layers[0], variables, "%s.linear_keys" % scope)
set_linear(split_layers[1], variables, "%s.linear_values" % scope)
utils.fuse_linear(spec.linear[1], split_layers)
set_linear(spec.linear[-1], variables, "%s.final_linear" % scope)
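    # Eole keeps a single relative-positions embedding table, so the same tensor
    # is registered for both the key and value slots of the spec.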
if hasattr(spec, "relative_position_keys"):
spec.relative_position_keys = _get_variable(
variables, "%s.relative_positions_embeddings.weight" % scope
)
spec.relative_position_values = spec.relative_position_keys


def set_layer_norm(spec, variables, scope):
try:
spec.gamma = _get_variable(variables, "%s.weight" % scope)
except KeyError:
# Compatibility with older models using a custom LayerNorm module.
spec.gamma = _get_variable(variables, "%s.a_2" % scope)
spec.beta = _get_variable(variables, "%s.b_2" % scope)
try:
spec.beta = _get_variable(variables, "%s.bias" % scope)
except KeyError:
pass


def set_linear(spec, variables, scope):
spec.weight = _get_variable(variables, "%s.weight" % scope)
bias = variables.get("%s.bias" % scope)
if bias is not None:
spec.bias = bias


def set_embeddings(spec, variables, scope):
spec.weight = _get_variable(variables, "%s.weight" % scope)


def set_position_encodings(spec, variables, scope):
spec.encodings = _get_variable(variables, "%s.pe" % scope).squeeze()


def _get_variable(variables, name):
return variables[name]


def main():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("--model_path", required=True, help="Model path.")
Converter.declare_arguments(parser)
args = parser.parse_args()
EoleConverter(args.model_path).convert_from_args(args)


if __name__ == "__main__":
main()
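
# Typical command-line usage (a sketch; --output_dir and the other conversion
# options come from Converter.declare_arguments):
#
#     python -m ctranslate2.converters.eole_ct2 \
#         --model_path /path/to/eole_model --output_dir ct2_model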
