
Deepspeed + CachedMNRL: Gradient computed twice for this partition. Multiple gradient reduction is currently not supported #3173

Open
Hypothesis-Z opened this issue Jan 15, 2025 · 2 comments


@Hypothesis-Z

Hypothesis-Z commented Jan 15, 2025

My script

import datasets
from sentence_transformers import (
    SentenceTransformer, 
    SentenceTransformerTrainer, 
    SentenceTransformerTrainingArguments, 
)
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers import models
from sentence_transformers.losses import MatryoshkaLoss, CachedMultipleNegativesRankingLoss

if __name__ == '__main__':
    training_args = SentenceTransformerTrainingArguments(
        output_dir="test",
        num_train_epochs=1,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        eval_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=2,
        logging_steps=100,
        run_name="pair-score", 
        gradient_checkpointing=True,
        deepspeed='ds_config.json',
        gradient_accumulation_steps=64,
    )
    word_embedding_model = models.Transformer("AI-ModelScope/Mistral-7B-v0.2-hf")
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), 'lasttoken')
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
    model.tokenizer.pad_token = model.tokenizer.eos_token
    model.tokenizer.add_eos_token = True
    model.tokenizer.padding_side = 'right'

    train_dataset = datasets.load_dataset("sentence-transformers/msmarco-bm25", "triplet", split="train[:10000]")
    loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=32)

    trainer = SentenceTransformerTrainer(
        model=model,
        train_dataset=train_dataset,
        loss=loss,
        args=training_args,
    )

    trainer.train()

My Deepspeed config

{
  "optimizer": {
    "type": "Adam",
    "params": {
      "bias_correction": true
    }
  },
  "zero_optimization": {
    "stage": 2,
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    }
  },
  "gradient_accumulation_steps": "auto",
  "train_micro_batch_size_per_gpu": 8,
  "bf16": {
    "enable": true
  }
}

Error

The parameter 289 has already been reduced. Gradient computed twice for this partition. Multiple gradient reduction is currently not supported

The exception is raised by DeepSpeed the second time surrogate.backward() is called in CachedMultipleNegativesRankingLoss.

def _backward_hook(
    grad_output: Tensor,
    sentence_features: Iterable[dict[str, Tensor]],
    loss_obj: CachedMultipleNegativesRankingLoss,
) -> None:
    """A backward hook to backpropagate the cached gradients mini-batch by mini-batch."""
    assert loss_obj.cache is not None
    assert loss_obj.random_states is not None
    with torch.enable_grad():
        for sentence_feature, grad, random_states in zip(sentence_features, loss_obj.cache, loss_obj.random_states):
            for (reps_mb, _), grad_mb in zip(
                loss_obj.embed_minibatch_iter(
                    sentence_feature=sentence_feature,
                    with_grad=True,
                    copy_random_state=False,
                    random_states=random_states,
                ),
                grad,
            ):
                surrogate = torch.dot(reps_mb.flatten(), grad_mb.flatten()) * grad_output
                surrogate.backward() # exception raised here
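
For context, the cached loss deliberately re-runs the forward pass mini-batch by mini-batch and calls backward() once per mini-batch, letting gradients accumulate in the parameters' .grad. That is legal in plain PyTorch, but ZeRO stage 2 reduces and partitions each parameter's gradient via hooks during the first backward pass, and then raises the assertion above when a second gradient for the same partition arrives. A minimal PyTorch-only sketch of the accumulation pattern (a toy illustration, not the library code):

import torch

# Toy model standing in for the sentence transformer.
layer = torch.nn.Linear(4, 1)
x = torch.randn(8, 4)

# Two mini-batches, two separate backward() calls over the same parameters.
# Plain PyTorch simply accumulates into layer.weight.grad; ZeRO stage 2 has
# already reduced the gradient after the first call and rejects the second.
for chunk in x.split(4):
    loss = layer(chunk).sum()
    loss.backward()

print(layer.weight.grad)  # sum of the gradients from both mini-batches
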
@tomaarsen
Collaborator

Hello!

Hmm, that is bothersome. Indeed, CachedMNRL computes gradients twice. It's based on GradCache where the same happens. I see that DeepSpeed is not expecting that, so there's an incompatibility there. I don't think there is a simple workaround for that at the moment, other than 1) not using DeepSpeed or 2) not using a Cached loss.

There are some small extra details on the GradCache project itself: luyug/GradCache#11
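
For completeness, a minimal sketch of option 2): swap the cached loss for the regular MultipleNegativesRankingLoss, which does a single forward/backward per step and therefore doesn't trip the ZeRO-2 assertion. The trade-off is that the in-batch negative pool is bounded by what fits in GPU memory at once rather than by mini_batch_size:

from sentence_transformers.losses import MultipleNegativesRankingLoss

# Single backward pass per step, so each gradient is reduced exactly once,
# at the cost of keeping the whole batch's embeddings in memory.
loss = MultipleNegativesRankingLoss(model)
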

  • Tom Aarsen

@Hypothesis-Z
Author

Hypothesis-Z commented Jan 17, 2025

@tomaarsen Thank you! I've upgraded accelerate/deepspeed/transformers to the latest versions and set the ZeRO stage to 1. I'm not sure whether the model is well trained, but it works, even with MatryoshkaLoss (from the main branch)...
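
For anyone landing here later, a sketch of the ZeRO stage 1 setup described above (the values are illustrative, not a verified configuration; note that the deepspeed argument of SentenceTransformerTrainingArguments also accepts an already-loaded dict instead of a path to a JSON file):

ds_config = {
    "optimizer": {"type": "Adam", "params": {"bias_correction": True}},
    "zero_optimization": {
        # Stage 1 only shards optimizer states; per the comment above,
        # the cached loss trains without the "already been reduced" assertion.
        "stage": 1,
    },
    "gradient_accumulation_steps": "auto",
    "train_micro_batch_size_per_gpu": 8,
    "bf16": {"enabled": True},
}

training_args = SentenceTransformerTrainingArguments(
    output_dir="test",
    deepspeed=ds_config,  # a dict works here as well as a file path
    gradient_checkpointing=True,
    gradient_accumulation_steps=64,
    # ... remaining arguments as in the original script
)
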
