StopStringCriteria relies on len(tokenizer) == model.config.vocab_size, leading to index errors #35244
Comments
Hi @Kripner - yes, this does look like a bug, although it should only be triggered rarely. I think the simplest solution would be to extend the stop string embedding matrix to […]

Hi @Rocketknight1, this might not be possible because […]
@Kripner Sounds good, but to keep static shapes for better compilation, how about:
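The snippet proposed here was not preserved in this extract; as a hedged sketch, a static-shape-friendly version of the idea (illustrative names, not necessarily the exact patch that landed) would clamp out-of-range IDs rather than filter them:

```python
import torch

# Sketch only: clamp token IDs that exceed the tokenizer's vocabulary
# instead of masking or filtering them, so tensor shapes stay constant
# and the op stays friendly to torch.compile / XLA. IDs >= len(tokenizer)
# decode to the empty string anyway; here they are redirected to the last
# valid ID as a placeholder (a real implementation might instead reserve
# a dedicated row that matches no stop string).
def clamp_token_ids(input_ids: torch.Tensor, tokenizer_vocab_size: int) -> torch.Tensor:
    return torch.clamp(input_ids, max=tokenizer_vocab_size - 1)
```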
I'm more familiar with XLA, so I don't know how much the torch compiler depends on static shapes, but if there's a static shape solution I think we should use it regardless!

@Rocketknight1 This looks great to me!

@Kripner Cool! Would you be willing to make that PR?

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
This fixes #35244 by clipping token IDs to be within the tokenizer's vocabulary size before performing the embedding lookup. This prevents index errors when `model.config.vocab_size > len(tokenizer)`. The fix:

1. Adds a clamp operation to ensure token IDs are within bounds.
2. Adds a test case to verify the behavior.
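For illustration, a hedged sketch of the behavior the added test presumably verifies (the checkpoint and stop string here are assumptions): a token ID above `len(tokenizer)` but below the padded vocab size should no longer crash the criteria.

```python
import torch
from transformers import AutoTokenizer
from transformers.generation import StopStringCriteria

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["stop"])

# The last ID lies above len(tokenizer) (50277) but below the padded model
# vocab size, mimicking a token the sampler can emit but the tokenizer
# cannot decode.
input_ids = torch.tensor([[10, 20, len(tokenizer) + 5]])
print(criteria(input_ids, scores=None))  # with the fix: returns normally, no index error
```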
System Info
Python: 3.12.0
Transformers: 4.46.3
Who can help?
@gante
@ArthurZucker
Reproduction
After fine-tuning EleutherAI/pythia-14m using the transformers Trainer, I run inference like this:
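(The original snippet is not preserved in this extract; below is a hedged reconstruction of the kind of call described, where the checkpoint path, prompt, and stop string are placeholders.)

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("path/to/finetuned-pythia-14m").to("cuda")
tokenizer = AutoTokenizer.from_pretrained("path/to/finetuned-pythia-14m")
tokenizer.pad_token_id = tokenizer.eos_token_id  # not set in Pythia's special_tokens_map.json

inputs = tokenizer("Some prompt", return_tensors="pt").to("cuda")
# Passing stop_strings makes generate() build a StopStringCriteria internally.
output = model.generate(**inputs, do_sample=True, max_new_tokens=100,
                        stop_strings=["\n"], tokenizer=tokenizer)
print(tokenizer.decode(output[0]))
```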
Note that `tokenizer.pad_token_id` has to be set explicitly because it is not present in Pythia's `special_tokens_map.json`. Run with `CUDA_LAUNCH_BLOCKING=1`, this code fails with an index error inside `StopStringCriteria`.

This is due to a mismatch between `len(tokenizer)` (50277) and `model.config.vocab_size` (50304 or 50432). The size of the embedding matrix was presumably rounded up to the next multiple of 128 or 256 for efficiency reasons. However, during sampling, tokens above `len(tokenizer)` can sometimes be generated. The tokenizer silently ignores such tokens, converting them to an empty string, but `StopStringCriteria` is implemented by indexing into an embedding whose size is determined by `len(tokenizer)`, and it therefore fails when it encounters a higher token.

A temporary fix is to explicitly suppress the unknown tokens from being generated:
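(This snippet is also missing from the extract; one hedged way to write it, reusing the objects from the reproduction above, is generate's `suppress_tokens` argument.)

```python
# Forbid the padded-but-unknown IDs so sampling can never emit a token
# that the tokenizer (and StopStringCriteria's embedding) cannot handle.
unknown_ids = list(range(len(tokenizer), model.config.vocab_size))
output = model.generate(**inputs, do_sample=True, max_new_tokens=100,
                        stop_strings=["\n"], tokenizer=tokenizer,
                        suppress_tokens=unknown_ids)
```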
I propose that a more principled solution would be to modify `StopStringCriteria` to ignore tokens above `len(tokenizer)`.

Expected behavior
The expected behavior of the `generate` method is to not fail.