Eoleconv #1832

Merged
9 commits merged on Dec 18, 2024
352 changes: 352 additions & 0 deletions python/ctranslate2/converters/eole_ct2.py
@@ -0,0 +1,352 @@
import argparse

from eole.config.run import PredictConfig
from eole.constants import PositionEncodingType
from eole.inputters.inputter import vocabs_to_dict
from eole.models.model import BaseModel

from ctranslate2.converters import utils
from ctranslate2.converters.converter import Converter
from ctranslate2.specs import common_spec, transformer_spec

_SUPPORTED_ACTIVATIONS = {
"gelu": common_spec.Activation.GELU,
"fast_gelu": common_spec.Activation.GELUTanh,
"relu": common_spec.Activation.RELU,
"gated-silu": common_spec.Activation.SWISH,
}
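# Note: "gated-silu" maps to SWISH here and additionally enables the gated
# feed-forward path (ffn_glu) in the spec builders below.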


def _get_model_spec_seq2seq(
config, variables, src_vocabs, tgt_vocabs, num_source_embeddings
):
"""Creates a model specification from the model config."""
with_relative_position = (
getattr(config.embeddings, "position_encoding_type", None)
== PositionEncodingType.Relative
)
with_rotary = (
getattr(config.embeddings, "position_encoding_type", None)
== PositionEncodingType.Rotary
)
if with_rotary:
raise ValueError(
"Rotary embeddings are not supported yet for encoder/decoder models"
)
with_alibi = (
getattr(config.embeddings, "position_encoding_type", None)
== PositionEncodingType.Alibi
)
if with_alibi:
raise ValueError("Alibi is not supported yet for encoder/decoder models")
activation_fn = getattr(config, "mlp_activation_fn", "relu")

# Return the first head of the last layer unless the model was trained with alignments.
if getattr(config.decoder, "lambda_align", 0) == 0:
alignment_layer = -1
alignment_heads = 1
else:
alignment_layer = config.decoder.alignment_layer
alignment_heads = config.decoder.alignment_heads

num_heads = getattr(config.decoder, "heads", 8)
# num_kv = getattr(config.decoder, "heads_kv", 0)
# if num_kv == num_heads or num_kv == 0:
# num_kv = None
# rotary_dim = 0 if with_rotary else None
# rotary_interleave = getattr(config.rope_config, "rotary_interleave", True)
ffn_glu = activation_fn == "gated-silu"
sliding_window = getattr(config, "sliding_window", 0)
if sliding_window != 0:
raise ValueError(
"Sliding window is not suported yet for encoder/decoder models"
)

model_spec = transformer_spec.TransformerSpec.from_config(
(config.encoder.layers, config.decoder.layers),
num_heads,
with_relative_position=with_relative_position,
# alibi=with_alibi,
activation=_SUPPORTED_ACTIVATIONS[activation_fn],
ffn_glu=ffn_glu,
rms_norm=config.layer_norm == "rms",
# rotary_dim=rotary_dim,
# rotary_interleave=rotary_interleave,
# num_heads_kv=num_kv,
# sliding_window=sliding_window,
alignment_layer=alignment_layer,
alignment_heads=alignment_heads,
num_source_embeddings=num_source_embeddings,
# multi_query_attention=getattr(opt, "multiquery", False),
)

set_transformer_spec(model_spec, variables)
for src_vocab in src_vocabs:
model_spec.register_source_vocabulary(src_vocab)
for tgt_vocab in tgt_vocabs:
model_spec.register_target_vocabulary(tgt_vocab)

return model_spec


def _get_model_spec_lm(
config, variables, src_vocabs, tgt_vocabs, num_source_embeddings
):
"""Creates a model specification from the model config."""
with_relative_position = (
getattr(config.embeddings, "position_encoding_type", None)
== PositionEncodingType.Relative
)
with_rotary = (
getattr(config.embeddings, "position_encoding_type", None)
== PositionEncodingType.Rotary
)
with_alibi = (
getattr(config.embeddings, "position_encoding_type", None)
== PositionEncodingType.Alibi
)
activation_fn = getattr(config, "mlp_activation_fn", "relu")
num_heads = getattr(config.decoder, "heads", 8)
num_kv = getattr(config.decoder, "heads_kv", 0)
if num_kv == num_heads or num_kv == 0:
num_kv = None
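    # heads_kv equal to heads (or left at 0) means standard multi-head attention,
    # so no separate KV head count is passed to the spec.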
rotary_dim = 0 if with_rotary else None
rotary_interleave = getattr(config.rope_config, "rotary_interleave", True)
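    # rotary_dim=0 requests rotary embeddings over the full head dimension;
    # rotary_interleave reproduces the layout the model was trained with.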
ffn_glu = activation_fn == "gated-silu"
sliding_window = getattr(config, "sliding_window", 0)

model_spec = transformer_spec.TransformerDecoderModelSpec.from_config(
config.decoder.layers,
num_heads,
activation=_SUPPORTED_ACTIVATIONS[activation_fn],
ffn_glu=ffn_glu,
with_relative_position=with_relative_position,
alibi=with_alibi,
rms_norm=config.layer_norm == "rms",
rotary_dim=rotary_dim,
rotary_interleave=rotary_interleave,
num_heads_kv=num_kv,
sliding_window=sliding_window,
# multi_query_attention=getattr(opt, "multiquery", False),
)

set_transformer_decoder(
model_spec.decoder,
variables,
with_encoder_attention=False,
)

for tgt_vocab in tgt_vocabs:
model_spec.register_vocabulary(tgt_vocab)

return model_spec


def get_vocabs(vocab):
src_vocabs = [vocab["src"]]
tgt_vocabs = [vocab["tgt"]]
return src_vocabs, tgt_vocabs


class EoleConverter(Converter):
"""Converts models generated by OpenNMT-py."""

def __init__(self, model_path: str):
"""Initializes the OpenNMT-py converter.

Arguments:
          model_path: Path to the Eole model.
"""
self._model_path = model_path

def _load(self):
import torch

config = PredictConfig(model_path=self._model_path, src="dummy")

vocabs, model, model_config = BaseModel.load_test_model(config)
vocabs_dict = vocabs_to_dict(vocabs)

config.model = model_config
src_vocabs, tgt_vocabs = get_vocabs(vocabs_dict)

if config.model.decoder.decoder_type == "transformer_lm":
spec = _get_model_spec_lm(
config.model,
model.state_dict(),
src_vocabs,
tgt_vocabs,
num_source_embeddings=len(src_vocabs),
)
else:
spec = _get_model_spec_seq2seq(
config.model,
model.state_dict(),
src_vocabs,
tgt_vocabs,
num_source_embeddings=len(src_vocabs),
)
spec.config.decoder_start_token = vocabs["decoder_start_token"]

spec.config.bos_token = vocabs["specials"]["bos_token"]
spec.config.eos_token = vocabs["specials"]["eos_token"]
spec.config.unk_token = vocabs["specials"]["unk_token"]
spec.config.layer_norm_epsilon = getattr(config, "norm_eps", 1e-6)

return spec


def set_transformer_spec(spec, variables):
set_transformer_encoder(spec.encoder, variables)
set_transformer_decoder(spec.decoder, variables)


def set_transformer_encoder(spec, variables):
set_input_layers(spec, variables, "src_emb")
set_layer_norm(spec.layer_norm, variables, "encoder.layer_norm")
for i, layer in enumerate(spec.layer):
set_transformer_encoder_layer(
layer, variables, "encoder.transformer_layers.%d" % i
)


def set_transformer_decoder(spec, variables, with_encoder_attention=True):
set_input_layers(spec, variables, "tgt_emb")
set_layer_norm(spec.layer_norm, variables, "decoder.layer_norm")
for i, layer in enumerate(spec.layer):
set_transformer_decoder_layer(
layer,
variables,
"decoder.transformer_layers.%d" % i,
with_encoder_attention=with_encoder_attention,
)

set_linear(spec.projection, variables, "generator")


def set_input_layers(spec, variables, scope):
if hasattr(spec, "position_encodings"):
set_position_encodings(
spec.position_encodings,
variables,
"%s.pe" % scope,
)
else:
spec.scale_embeddings = False

embeddings_specs = spec.embeddings
    # Encoder embeddings are stored in a list (onmt/ct2 legacy with features).
if isinstance(embeddings_specs, list):
embeddings_specs = embeddings_specs[0]
set_embeddings(embeddings_specs, variables, "%s.embeddings" % scope)


def set_transformer_encoder_layer(spec, variables, scope):
set_multi_head_attention(
spec.self_attention,
variables,
"%s.self_attn" % scope,
self_attention=True,
)
set_layer_norm(
spec.self_attention.layer_norm, variables, "%s.input_layernorm" % scope
)
set_layer_norm(
spec.ffn.layer_norm, variables, "%s.post_attention_layernorm" % scope
)
set_ffn(spec.ffn, variables, "%s.mlp" % scope)


def set_transformer_decoder_layer(spec, variables, scope, with_encoder_attention=True):
set_multi_head_attention(
spec.self_attention,
variables,
"%s.self_attn" % scope,
self_attention=True,
)
set_layer_norm(
spec.self_attention.layer_norm, variables, "%s.input_layernorm" % scope
)
if with_encoder_attention:
set_multi_head_attention(spec.attention, variables, "%s.context_attn" % scope)
set_layer_norm(
spec.attention.layer_norm, variables, "%s.precontext_layernorm" % scope
)
set_layer_norm(
spec.ffn.layer_norm, variables, "%s.post_attention_layernorm" % scope
)
set_ffn(spec.ffn, variables, "%s.mlp" % scope)


def set_ffn(spec, variables, scope):
set_linear(spec.linear_0, variables, "%s.gate_up_proj" % scope)
set_linear(spec.linear_1, variables, "%s.down_proj" % scope)
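    # linear_0_noact only exists on gated (GLU) feed-forward specs; here it
    # receives the up_proj weights while linear_0 takes the gate_up_proj weights.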
if hasattr(spec, "linear_0_noact"):
set_linear(spec.linear_0_noact, variables, "%s.up_proj" % scope)


def set_multi_head_attention(spec, variables, scope, self_attention=False):
if self_attention:
split_layers = [common_spec.LinearSpec() for _ in range(3)]
set_linear(split_layers[0], variables, "%s.linear_query" % scope)
set_linear(split_layers[1], variables, "%s.linear_keys" % scope)
set_linear(split_layers[2], variables, "%s.linear_values" % scope)
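        # Eole stores separate query/key/value projections; fuse them into the
        # single input projection expected by the CTranslate2 attention spec.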
utils.fuse_linear(spec.linear[0], split_layers)
else:
set_linear(spec.linear[0], variables, "%s.linear_query" % scope)
split_layers = [common_spec.LinearSpec() for _ in range(2)]
set_linear(split_layers[0], variables, "%s.linear_keys" % scope)
set_linear(split_layers[1], variables, "%s.linear_values" % scope)
utils.fuse_linear(spec.linear[1], split_layers)
set_linear(spec.linear[-1], variables, "%s.final_linear" % scope)
if hasattr(spec, "relative_position_keys"):
spec.relative_position_keys = _get_variable(
variables, "%s.relative_positions_embeddings.weight" % scope
)
spec.relative_position_values = spec.relative_position_keys


def set_layer_norm(spec, variables, scope):
try:
spec.gamma = _get_variable(variables, "%s.weight" % scope)
except KeyError:
# Compatibility with older models using a custom LayerNorm module.
spec.gamma = _get_variable(variables, "%s.a_2" % scope)
spec.beta = _get_variable(variables, "%s.b_2" % scope)
try:
spec.beta = _get_variable(variables, "%s.bias" % scope)
except KeyError:
pass


def set_linear(spec, variables, scope):
spec.weight = _get_variable(variables, "%s.weight" % scope)
bias = variables.get("%s.bias" % scope)
if bias is not None:
spec.bias = bias


def set_embeddings(spec, variables, scope):
spec.weight = _get_variable(variables, "%s.weight" % scope)


def set_position_encodings(spec, variables, scope):
spec.encodings = _get_variable(variables, "%s.pe" % scope).squeeze()


def _get_variable(variables, name):
return variables[name]


def main():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("--model_path", required=True, help="Model path.")
Converter.declare_arguments(parser)
args = parser.parse_args()
EoleConverter(args.model_path).convert_from_args(args)


if __name__ == "__main__":
main()
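# A minimal usage sketch (paths are hypothetical, assuming the standard base
# Converter API): the conversion can also be run programmatically instead of
# through the command line entry point above, e.g.
#
#     converter = EoleConverter("/path/to/eole_model")
#     converter.convert("/path/to/ct2_output", quantization="int8", force=True)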