StopStringCriteria relies on len(tokenizer)==model.config.vocab_size, leading to index errors #35244

Closed
Kripner opened this issue Dec 12, 2024 · 6 comments · May be fixed by #35797

@Kripner
Contributor

Kripner commented Dec 12, 2024

System Info

Python: 3.12.0
Transformers: 4.46.3

Who can help?

@gante
@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

After fine-tuning EleutherAI/pythia-14m with the transformers Trainer, I run inference like this:

import copy

from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "models/checkpoint-166000"
device = "cuda"
max_length = 512  # assumed value; used below for max_new_tokens

model = AutoModelForCausalLM.from_pretrained(checkpoint)
model.to(device)

tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left")
tokenizer.pad_token_id = 1
tokenizer.pad_token = "<|padding|>"

prompts = [
    "prompt1",
    "prompt2",
]
inputs = tokenizer(
    prompts, return_tensors="pt", padding=True, truncation=True, max_length=512,
)

gen_config = copy.deepcopy(model.generation_config)
gen_config.update(
    max_new_tokens=max_length,
    do_sample=True,
    top_k=0,
    pad_token_id=tokenizer.pad_token_id,
    stop_strings="end",
)
gen_config.validate()

outputs = model.generate(
    input_ids=inputs["input_ids"].to(device),
    attention_mask=inputs["attention_mask"].to(device),
    num_return_sequences=32,
    generation_config=gen_config,
    output_scores=True,
    return_dict_in_generate=True,
    tokenizer=tokenizer,
)

Note that tokenizer.pad_token_id has to be set explicitly because it is not present in Pythia's special_tokens_map.json. This code leads to the following error (run with CUDA_LAUNCH_BLOCKING=1):

../aten/src/ATen/native/cuda/Indexing.cu:1308: indexSelectLargeIndex: block: [1,0,0], thread: [100,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1308: indexSelectLargeIndex: block: [1,0,0], thread: [101,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1308: indexSelectLargeIndex: block: [1,0,0], thread: [102,0,0] Assertion `srcIndex < srcSelectDimSize` failed.      
../aten/src/ATen/native/cuda/Indexing.cu:1308: indexSelectLargeIndex: block: [1,0,0], thread: [103,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
Traceback (most recent call last):                   
  File "home/m/src/playground.py", line 43, in <module>                 
    outputs = model.generate(
              ^^^^^^^^^^^^^^^
  File "/home/m/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context                            
    return func(*args, **kwargs)             
           ^^^^^^^^^^^^^^^^^^^^^                        
  File "/home/m/venv/lib/python3.12/site-packages/transformers/generation/utils.py", line 2215, in generate                             
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/home/m/venv/lib/python3.12/site-packages/transformers/generation/utils.py", line 3262, in _sample                              
    unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
                                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/m/venv/lib/python3.12/site-packages/transformers/generation/stopping_criteria.py", line 496, in __call__                  
    is_done = is_done | criteria(input_ids, scores, **kwargs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/m/venv/lib/python3.12/site-packages/transformers/generation/stopping_criteria.py", line 402, in __call__                  
    embedded = F.embedding(flipped_ids, self.embedding_vec)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^   
  File "/home/m/venv/lib/python3.12/site-packages/torch/nn/functional.py", line 2551, in embedding                                      
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)  
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

This is due to a mismatch between len(tokenizer) (50277) and model.config.vocab_size (50304 or 50432). The decision to round the embedding matrix size up to the next multiple of 128 or 256 was presumably made for efficiency reasons. However, during sampling, token IDs at or above len(tokenizer) can sometimes be generated. The tokenizer silently ignores them, decoding such tokens to an empty string, but StopStringCriteria indexes into an embedding whose size is determined by len(tokenizer), so it fails when it encounters one of these higher token IDs.
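
For illustration, the mismatch (and the silent decoding) can be confirmed with a quick check like the one below; the numbers are the ones observed for the pythia-14m checkpoint above and may differ for other Pythia variants.

print(len(tokenizer))             # 50277 -- token IDs the tokenizer actually knows
print(model.config.vocab_size)    # 50304 -- embedding rows, rounded up for efficiency
# As described above, an out-of-range ID is decoded silently to an empty string.
print(repr(tokenizer.decode([len(tokenizer) + 1])))  # ''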

A temporary fix is to explicitly suppress the unknown tokens from being generated:

if len(tokenizer) < model.config.vocab_size:
    model.generation_config.suppress_tokens = list(range(len(tokenizer), model.config.vocab_size))

I propose that a more principled solution would be to modify StopStringCriteria to ignore token IDs at or above len(tokenizer).

Expected behavior

The generate method should not fail in this situation.

@Kripner Kripner added the bug label Dec 12, 2024
@Rocketknight1
Member

Hi @Kripner - yes, this does look like a bug, although it should only be triggered rarely. I think the simplest solution would be to extend the stop string embedding matrix to model.config.vocab_size, but keep everything else the same. The extra rows will basically be 'null', and so the extra tokens cannot contribute to a stop string match, but at least it won't crash the whole library. What do you think?

@Kripner
Contributor Author

Kripner commented Dec 13, 2024


Hi @Rocketknight1, this might not be possible because StoppingCriteria has access only to the tokenizer, not to the model. I think the fix should go in StopStringCriteria's __call__ method, for example by cropping flipped_ids at the first occurrence of an unknown token.

@Rocketknight1
Member

@Kripner Sounds good, but to keep static shapes for better compilation, how about:

  • Make a mask array for flipped_ids >= len(tokenizer)
  • Clip the values in flipped_ids to len(tokenizer) - 1 so there are no more embedding lookup errors
  • Apply the mask after the embedding lookup, and set all the masked positions to -1 so that token matches end at that point

I'm more familiar with XLA, so I don't know how much the torch compiler depends on static shapes, but if there's a static shape solution I think we should use it regardless!
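
In rough, untested pseudocode (reusing the flipped_ids / embedding_vec names from the traceback above, not the actual implementation), those three steps might look like:

import torch
import torch.nn.functional as F

def masked_stop_string_embedding(flipped_ids, embedding_vec):
    # Number of rows in the stop string embedding, i.e. len(tokenizer).
    vocab_size = embedding_vec.size(0)
    # 1. Mask marking token IDs the tokenizer does not know about.
    out_of_vocab = flipped_ids >= vocab_size
    # 2. Clip so the embedding lookup itself can never go out of bounds.
    clipped_ids = flipped_ids.clamp(max=vocab_size - 1)
    embedded = F.embedding(clipped_ids, embedding_vec)
    # 3. Set masked positions to -1 so they can never contribute to a stop string match.
    return torch.where(out_of_vocab.unsqueeze(-1), torch.full_like(embedded, -1.0), embedded)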

@Kripner
Contributor Author

Kripner commented Dec 13, 2024


@Rocketknight1 This looks great to me!

@Rocketknight1
Member

@Kripner Cool! Would you be willing to make that PR?


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

Rocketknight1 pushed a commit that referenced this issue Jan 20, 2025
This fixes #35244 by clipping token IDs to be within the tokenizer's vocabulary size before performing the embedding lookup. This prevents index errors when model.config.vocab_size > len(tokenizer).

The fix:
1. Adds a clamp operation to ensure token IDs are within bounds
2. Adds a test case to verify the behavior
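
For context, a minimal check in the spirit of point 2 (illustrative only, not the actual test from the commit) would simply assert that StopStringCriteria tolerates token IDs at or above len(tokenizer):

import torch
from transformers import AutoTokenizer
from transformers.generation.stopping_criteria import StopStringCriteria

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["end"])

# Token IDs at or above len(tokenizer), as a padded embedding matrix can produce during sampling.
input_ids = torch.tensor([[len(tokenizer), len(tokenizer) + 10]])
scores = torch.zeros(1, len(tokenizer))  # scores are not used by StopStringCriteria
is_done = criteria(input_ids, scores)    # should return a bool tensor instead of raising an index error
assert is_done.shape == (1,)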