feat: add soft distillation #736
base: main
Conversation
I think one of the most valuable things we've added is the documentation: both parallel vocabulary cross entropy and soft cross entropy can be hard to understand because each applies simplifications to the formulas, all of which are explained in the docs. Is there a way to bring the MD files we've created into this diff?
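For reviewers who don't have the MD files handy, here is a minimal sketch of the per-token soft cross-entropy being referred to, before any of the vocab-parallel simplifications. The notation is mine, not the docs': z_j are student logits over the full vocabulary V, K is the set of teacher top-k tokens, and w_k are the exponentiated teacher logprobs.

```latex
\mathcal{L} \;=\; \sum_{k \in K} w_k \Big( \log \sum_{j \in V} e^{z_j} \;-\; z_k \Big),
\qquad w_k = e^{\ell^{\text{teacher}}_k}.
```

The vocab-parallel variant splits V across tensor-parallel ranks, so the two global quantities (the log-sum-exp over the full vocabulary and the gathered z_k for the teacher's top-k tokens) are each recovered with an all-reduce, which is what the forward code further down in this diff does.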
@@ -59,10 +60,10 @@ def log_weight_stats(tensor, name):
)


class ModelParallelTransformerDecoder(BaseDecoder):
I don't recall changing this class in our implementation, but this might have been changed by Yu or Sahaj. It would be a good idea to include Sahaj in this review too.
I don't think we can modify reviewers since we are not contributors/admins on the repo.
However, I think we can mention them here: @anselmwang @sahajgg
# Gather output if model is in inference mode (i.e. evallm or generation) cause both are not yet compatible with
# parallel vocab embeddings
criterion = getattr(self.args, "criterion")
is_parallel_criterion = criterion.find("vocab_parallel") != -1
if not is_parallel_criterion or getattr(self, "inference", False):
    x = gather_from_tensor_model_parallel_region(x).contiguous()
This is the only change in this class that I can confirm I made. Please check whether the others should be double-checked with Sahaj and Yu; my concern is that elsewhere in the diff we might be reverting important changes from the original code.
target_mask = (target_tokens < vocab_start_index) | (target_tokens >= vocab_end_index)
masked_target = target_tokens.clone() - vocab_start_index
masked_target[target_mask] = 0

# Get predicted-logits = logits[top_logprobs].
predicted_logits = vocab_parallel_logits.gather(dim=-1, index=masked_target)
predicted_logits[target_mask] = 0.0
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(
    predicted_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()
)

# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = vocab_parallel_logits
torch.exp(vocab_parallel_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
torch.distributed.all_reduce(
    sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()
)

# Loss = log(sum(exp(logits))) - predicted-logit.
target_weights = target_predictions.exp()
loss = ((torch.log(sum_exp_logits).unsqueeze(dim=-1) - predicted_logits) * target_weights).sum(-1)

# Store softmax, top_logprobs-mask and masked-top_logprobs for backward pass.
softmax = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(softmax, target_mask, masked_target, target_weights)
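For reviewers tracing the backward pass: differentiating the per-token loss above, L = Σ_k w_k (log Σ_j e^{z_j} − z_k), with respect to a logit z_j gives the expression below (my own sketch in the same illustrative notation as earlier, not copied from the PR), which is presumably why the softmax, the target mask, the masked targets, and the teacher weights are the tensors saved for backward.

```latex
\frac{\partial \mathcal{L}}{\partial z_j}
  \;=\; \Big(\sum_{k \in K} w_k\Big)\, p_j \;-\; w_j\,\mathbb{1}[\, j \in K \,],
\qquad p_j = \frac{e^{z_j}}{\sum_{i \in V} e^{z_i}}.
```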
If I remember correctly, the majority of the work was in this section of code. @clarissesimoes can you confirm?
If not, can you add a comment calling out other notable places in the code that reviewers should pay closer attention to?
Also, there is a documentation file included in the PR to help explain this code.
I confirm, and I'd also call out the file metaseq/tasks/streaming_distillation_language_modeling.py as equally important, since data preprocessing and masks are slightly different for distillation compared to finetuning.
has_megatron_submodule = False


class _VocabParallelMSELoss(torch.autograd.Function):
From what I remember, we implemented this thinking it would work for both logits and logprobs, but it only works for logits. Because we could only get logprobs from the OpenAI model output and couldn't convert them to logits, this loss effectively went unused.
However, we left it in because the implementation may be valuable for other applications where logit values are available.
We've tested MSE loss with logprobs, but training never converged. MSE loss can be used if the input is teacher logits, though.
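For concreteness, here is a single-GPU sketch of what MSE distillation on raw logits looks like. This is illustrative only; the helper name, shapes, and masking convention are assumptions, and it is not the _VocabParallelMSELoss in this diff.

```python
import torch
import torch.nn.functional as F

def mse_logit_distillation_loss(student_logits: torch.Tensor,
                                teacher_logits: torch.Tensor,
                                pad_mask: torch.Tensor) -> torch.Tensor:
    """Illustrative MSE distillation on raw logits (hypothetical helper, not metaseq code).

    student_logits, teacher_logits: [batch, seq, vocab]
    pad_mask: [batch, seq] bool, True at positions to ignore (e.g. padding or prompt tokens)
    """
    # Element-wise squared error, averaged over the vocabulary dimension.
    per_token = F.mse_loss(student_logits, teacher_logits, reduction="none").mean(dim=-1)
    # Zero out ignored positions and average over the remaining ones.
    per_token = per_token.masked_fill(pad_mask, 0.0)
    denom = (~pad_mask).sum().clamp(min=1)
    return per_token.sum() / denom

# Usage sketch:
# loss = mse_logit_distillation_loss(student_out, teacher_out, batch["pad_mask"])
```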
Background
One of the main goals of our project's fork was to implement "soft" distillation (training on a set of teacher logprobs rather than on the correctness of a single token class) and to measure the efficacy of this technique compared to normal finetuning.
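To make "soft" vs. "hard" concrete, here is a single-position PyTorch sketch. All numbers and names are made up for illustration, and this is not the vocab-parallel criterion added in this PR: ordinary finetuning scores only the one correct token id, while soft distillation weights the student's negative logprob of each teacher top-k token by the teacher's probability for that token.

```python
import torch
import torch.nn.functional as F

vocab = 32
student_logits = torch.randn(vocab, requires_grad=True)   # one sequence position, for clarity

# Hard target: a single "correct" token id, as in ordinary finetuning.
hard_target = torch.tensor(7)
hard_loss = F.cross_entropy(student_logits.unsqueeze(0), hard_target.unsqueeze(0))

# Soft target: the teacher's top-k token ids and logprobs (e.g. what an API returns).
teacher_topk_ids = torch.tensor([7, 11, 3, 29])
teacher_topk_logprobs = torch.tensor([-0.4, -1.6, -2.5, -3.0])
teacher_weights = teacher_topk_logprobs.exp()              # teacher probabilities of the top-k tokens

# Soft cross entropy: teacher-weighted negative student logprobs of the top-k tokens.
student_logprobs = F.log_softmax(student_logits, dim=-1)
soft_loss = -(teacher_weights * student_logprobs[teacher_topk_ids]).sum()

print(float(hard_loss), float(soft_loss))
```

When the teacher weights collapse to a one-hot vector on the correct token, the soft loss reduces to the ordinary hard cross entropy.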
From our docs:
Issue
Solution
New task: streaming_distillation_language_modeling
New criterion: vocab_parallel_soft_cross_entropy
(Note: Soft)
Flags:
--task streaming_distillation_language_modeling
--distillation-mode logprobs_distillation
--criterion vocab_parallel_soft_cross_entropy
Testing
Did not test
Related to #726
This feature was implemented by @anselmwang and @clarissesimoes