Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
vince62s committed Dec 17, 2024
1 parent 64b9803 commit 7929f8b
Showing 1 changed file with 12 additions and 22 deletions.
34 changes: 12 additions & 22 deletions python/ctranslate2/converters/eole_ct2.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,6 @@ def _get_model_spec_seq2seq(
# multi_query_attention=getattr(opt, "multiquery", False),
)

- # model_spec.config.decoder_start_token = getattr(opt, "decoder_start_token", "<s>")
-
set_transformer_spec(model_spec, variables)
for src_vocab in src_vocabs:
model_spec.register_source_vocabulary(src_vocab)
Expand Down Expand Up @@ -132,8 +130,6 @@ def _get_model_spec_lm(
# multi_query_attention=getattr(opt, "multiquery", False),
)

- model_spec.config.layer_norm_epsilon = getattr(config, "norm_eps", 1e-6)
-
set_transformer_decoder(
model_spec.decoder,
variables,
Expand All @@ -147,22 +143,8 @@ def _get_model_spec_lm(


def get_vocabs(vocab):
- if isinstance(vocab, dict) and "src" in vocab:
- if isinstance(vocab["src"], list):
- src_vocabs = [vocab["src"]]
- tgt_vocabs = [vocab["tgt"]]
-
- src_feats = vocab.get("src_feats")
- if src_feats is not None:
- src_vocabs.extend(src_feats.values())
- else:
- src_vocabs = [field[1].vocab.itos for field in vocab["src"].fields]
- tgt_vocabs = [field[1].vocab.itos for field in vocab["tgt"].fields]
- else:
- # Compatibility with older models.
- src_vocabs = [vocab[0][1].itos]
- tgt_vocabs = [vocab[1][1].itos]
-
+ src_vocabs = [vocab["src"]]
+ tgt_vocabs = [vocab["tgt"]]
return src_vocabs, tgt_vocabs


Expand All @@ -189,21 +171,29 @@ def _load(self):
src_vocabs, tgt_vocabs = get_vocabs(vocabs_dict)

if config.model.decoder.decoder_type == "transformer_lm":
- return _get_model_spec_lm(
+ spec = _get_model_spec_lm(
config.model,
model.state_dict(),
src_vocabs,
tgt_vocabs,
num_source_embeddings=len(src_vocabs),
)
else:
- return _get_model_spec_seq2seq(
+ spec = _get_model_spec_seq2seq(
config.model,
model.state_dict(),
src_vocabs,
tgt_vocabs,
num_source_embeddings=len(src_vocabs),
)
+ spec.config.decoder_start_token = vocabs["decoder_start_token"]
+
+ spec.config.bos_token = vocabs["specials"]["bos_token"]
+ spec.config.eos_token = vocabs["specials"]["eos_token"]
+ spec.config.unk_token = vocabs["specials"]["unk_token"]
+ spec.config.layer_norm_epsilon = getattr(config, "norm_eps", 1e-6)
+
+ return spec


def set_transformer_spec(spec, variables):
Expand Down

0 comments on commit 7929f8b

Please sign in to comment.