This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

feat: add soft distillation #736

Open
wants to merge 4 commits into
base: main
from

Conversation

mattmazzola
Contributor

⚠️ This PR is not intended to be merged directly. Its purpose is to share features that may be useful for Metaseq. ⚠️

Background

One of the main goals of our project's fork was to implement "soft" distillation (training on a set of logprobs rather than on the correctness of the token class) and to measure the efficacy of this technique compared to normal finetuning.

From our docs:

The motivation for training on log probabilities rather than token classes is to pass as much knowledge from the teacher to the student as possible. [... By the teacher providing] log probabilities of other tokens in the vocabulary [we expect] the student [to] better learn to represent the teacher’s knowledge.

Issue

  • Soft Distillation was not implemented

Solution

  • Add new pipeline task streaming_distillation_language_modeling
    • Add new criterion vocab_parallel_soft_cross_entropy (Note: Soft)
      • Considers multiple possible predictions for each token of the target sequence
    • Adds new parameters
      --task streaming_distillation_language_modeling
      --distillation-mode logprobs_distillation
      --criterion vocab_parallel_soft_cross_entropy
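
As a rough illustration of what "considers multiple possible predictions for each token" means, here is a minimal pure-Python sketch for a single target position. It is not the PR's implementation: the function name `soft_cross_entropy` and the input layout (a list of teacher token ids with matching logprobs) are assumptions made for this example. It follows the formula quoted later in the diff, `(log(sum(exp(logits))) - predicted_logit) * target_weight`, summed over the teacher's candidates.

```python
import math


def soft_cross_entropy(student_logits, teacher_token_ids, teacher_logprobs):
    """Soft cross-entropy for one target position (illustrative sketch).

    student_logits: list of floats, one per vocabulary entry.
    teacher_token_ids / teacher_logprobs: the teacher's candidate tokens for
    this position and their log probabilities (hypothetical input layout).
    """
    # log(sum(exp(logits))), computed stably by subtracting the max logit.
    m = max(student_logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in student_logits))

    loss = 0.0
    for tok, lp in zip(teacher_token_ids, teacher_logprobs):
        weight = math.exp(lp)  # teacher probability for this candidate
        # weight * -log p_student(tok)
        loss += weight * (log_z - student_logits[tok])
    return loss


logits = [2.0, 0.5, -1.0, 0.0]

# "Hard" target: a single token with probability 1 (ordinary cross-entropy).
hard = soft_cross_entropy(logits, [0], [0.0])

# "Soft" target: the teacher spreads probability over two candidate tokens.
soft = soft_cross_entropy(logits, [0, 1], [math.log(0.7), math.log(0.3)])
```

With a single candidate of probability 1, the sketch reduces to ordinary cross-entropy, which is one way to sanity-check a soft criterion against the existing hard one.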

Testing

Did not test

Related to #726

This feature was implemented by @anselmwang and @clarissesimoes


@clarissesimoes clarissesimoes left a comment

I think one of the most valuable things we've added is the documentation, since both parallel vocabulary cross entropy and soft cross entropy can be hard to understand: both apply simplifications to the formulas, all of which are explained in the docs. Is there a way to bring the MD files that we've created into this diff?

@@ -59,10 +60,10 @@ def log_weight_stats(tensor, name):
)


class ModelParallelTransformerDecoder(BaseDecoder):

I don't recall changing this class in our implementation, but this might have been changed by Yu or Sahaj. It would be a good idea to include Sahaj in this review too.

Contributor Author

I don't think we can modify reviewers since we are not contributors/admins on the repo.
However, I think we can mention them here: @anselmwang @sahajgg

Comment on lines +543 to +548
# Gather output if model is in inference mode (i.e. evallm or generation) cause both are not yet compatible with
# parallel vocab embeddings
criterion = getattr(self.args, "criterion")
is_parallel_criterion = criterion.find("vocab_parallel") != -1
if not is_parallel_criterion or getattr(self, "inference", False):
    x = gather_from_tensor_model_parallel_region(x).contiguous()

This is the only change I can confirm I made in this class. Please confirm whether the others should be double-checked with Sahaj and Yu. My concern is that the rest of the diff might be reverting important changes from the original code.
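
To make the condition in the snippet above easier to review, here is a small standalone sketch of the same decision. The helper name `should_gather_output` is hypothetical and does not exist in the PR; it only restates the logic that the output is gathered unless a `vocab_parallel` criterion is in use and the model is not in inference mode.

```python
def should_gather_output(criterion_name, inference=False):
    """Return True when the decoder output should be gathered across ranks.

    Hypothetical helper that mirrors the condition in the diff:
    `"vocab_parallel" in name` is equivalent to `name.find("vocab_parallel") != -1`.
    """
    is_parallel_criterion = "vocab_parallel" in criterion_name
    return (not is_parallel_criterion) or inference


# Training with a vocab-parallel criterion: keep the output sharded.
train_parallel = should_gather_output("vocab_parallel_soft_cross_entropy")

# Inference (e.g. evaluation or generation): always gather.
infer_parallel = should_gather_output("vocab_parallel_soft_cross_entropy", inference=True)

# Any non-parallel criterion: gather.
train_plain = should_gather_output("cross_entropy")
```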

Comment on lines +40 to +66
target_mask = (target_tokens < vocab_start_index) | (target_tokens >= vocab_end_index)
masked_target = target_tokens.clone() - vocab_start_index
masked_target[target_mask] = 0

# Get predicted-logits = logits[top_logprobs].
predicted_logits = vocab_parallel_logits.gather(dim=-1, index=masked_target)
predicted_logits[target_mask] = 0.0
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(
    predicted_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()
)

# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = vocab_parallel_logits
torch.exp(vocab_parallel_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
torch.distributed.all_reduce(
    sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()
)

# Loss = log(sum(exp(logits))) - predicted-logit.
target_weights = target_predictions.exp()
loss = ((torch.log(sum_exp_logits).unsqueeze(dim=-1) - predicted_logits) * target_weights).sum(-1)

# Store softmax, top_logprobs-mask and masked-top_logprobs for backward pass.
softmax = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(softmax, target_mask, masked_target, target_weights)
Contributor Author

If I remember correctly, the majority of the work was in this section of code. @clarissesimoes can you confirm?
If not, can you add a comment calling out other notable places in the code that reviewers should pay closer attention to?

Contributor Author

Also, there is a documentation file included in this PR to help explain this code.

I confirm. I'd also mention the file metaseq/tasks/streaming_distillation_language_modeling.py as equally important, since data preprocessing and masks are slightly different for distillation than for finetuning.
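
For readers unfamiliar with the vocab-parallel pattern in the forward pass quoted above, the following pure-Python sketch simulates it without GPUs. It is an illustration under stated assumptions, not the PR's code: the vocabulary is split into contiguous shards (one per simulated rank), plain Python sums stand in for `torch.distributed.all_reduce` with `ReduceOp.SUM`, and the function names `shard_partial` and `parallel_soft_loss` are made up for this example. It also omits any numerical-stability handling.

```python
import math


def shard_partial(shard_logits, vocab_start, target_tokens):
    """What one simulated rank computes locally before the all-reduce."""
    vocab_end = vocab_start + len(shard_logits)
    predicted = []
    for tok in target_tokens:
        out_of_range = tok < vocab_start or tok >= vocab_end
        # Tokens owned by another rank contribute 0 on this rank.
        predicted.append(0.0 if out_of_range else shard_logits[tok - vocab_start])
    sum_exp = sum(math.exp(x) for x in shard_logits)
    return predicted, sum_exp


def parallel_soft_loss(shards, target_tokens, target_logprobs):
    """Soft cross-entropy for one position over a sharded vocabulary."""
    predicted = [0.0] * len(target_tokens)
    sum_exp = 0.0
    start = 0
    for shard in shards:
        part, part_sum = shard_partial(shard, start, target_tokens)
        # These two sums play the role of the two all-reduce calls.
        predicted = [a + b for a, b in zip(predicted, part)]
        sum_exp += part_sum
        start += len(shard)
    log_z = math.log(sum_exp)
    # Loss = sum over candidates of weight * (log(sum(exp(logits))) - predicted_logit).
    return sum(math.exp(lp) * (log_z - p) for lp, p in zip(target_logprobs, predicted))


full_logits = [2.0, 0.5, -1.0, 0.0]
shards = [full_logits[:2], full_logits[2:]]  # two simulated ranks
tokens = [0, 3]                              # one candidate on each rank
logprobs = [math.log(0.6), math.log(0.4)]

sharded = parallel_soft_loss(shards, tokens, logprobs)
unsharded = parallel_soft_loss([full_logits], tokens, logprobs)
```

Because each target token falls inside exactly one shard, summing the masked per-rank values recovers the same predicted logits as the unsharded computation, so `sharded` and `unsharded` agree up to floating-point rounding.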

has_megatron_submodule = False


class _VocabParallelMSELoss(torch.autograd.Function):
Contributor Author

From what I remember, we implemented this thinking it would work for logits AND logprobs, but it only works for logits. Then, because we could only get logprobs from the OpenAI model output and couldn't convert them to logits, this loss effectively became unused.

However, we left it in because the implementation may be valuable for other applications where logit values are available.

We tested MSE loss with logprobs, but training never converged. MSE loss can be used if the input is teacher logits, though.
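
One way to see why logprobs cannot be converted back to the teacher's original logits: log-softmax discards any constant added to all logits, so different logit vectors produce identical logprobs. The sketch below is a minimal illustration of that fact in pure Python; the helper names are made up for the example, and it does not by itself explain the non-convergence reported above.

```python
import math


def log_softmax(logits):
    """Log probabilities from logits, computed stably."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def mse(a, b):
    """Mean squared error between two equal-length lists."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


teacher_logits = [2.0, 0.5, -1.0]
shifted_logits = [x + 5.0 for x in teacher_logits]  # same distribution, different logits

logprobs_a = log_softmax(teacher_logits)
logprobs_b = log_softmax(shifted_logits)

# The two logit vectors are indistinguishable once reduced to logprobs ...
mse_between_logprobs = mse(logprobs_a, logprobs_b)

# ... but an MSE loss on logits treats them as different targets.
mse_between_logits = mse(teacher_logits, shifted_logits)
```

Here `mse_between_logprobs` is zero up to rounding, while `mse_between_logits` is not, so an MSE criterion defined on logits has no unique target when only logprobs are available.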
