Adding reverse and symmetric KLD losses (#2094)

insop authored Jan 30, 2025
1 parent be4ff50 commit 6764618
Showing 4 changed files with 507 additions and 2 deletions.
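For context: these losses target knowledge distillation, where a student model is trained to match a teacher's per-token distributions. Forward KL is KL(teacher || student), reverse KL flips the direction to KL(student || teacher), and the symmetric variant mixes the two. A minimal sketch of the three quantities (not the torchtune implementation; the equal-weight mix in symmetric_kl is an assumption, though it is consistent with the expected values in the tests below):

import torch
import torch.nn.functional as F

def forward_kl(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    # KL(teacher || student): teacher probabilities weight the log-ratio.
    log_ps = F.log_softmax(student_logits.float(), dim=-1)
    log_pt = F.log_softmax(teacher_logits.float(), dim=-1)
    return (log_pt.exp() * (log_pt - log_ps)).sum(dim=-1)

def reverse_kl(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    # KL(student || teacher): student probabilities weight the log-ratio.
    log_ps = F.log_softmax(student_logits.float(), dim=-1)
    log_pt = F.log_softmax(teacher_logits.float(), dim=-1)
    return (log_ps.exp() * (log_ps - log_pt)).sum(dim=-1)

def symmetric_kl(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    # Assumed form: equal-weight mix of the two directions.
    return 0.5 * forward_kl(student_logits, teacher_logits) + 0.5 * reverse_kl(
        student_logits, teacher_logits
    )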
209 changes: 208 additions & 1 deletion tests/torchtune/modules/loss/test_kd_losses.py
@@ -7,7 +7,14 @@
import pytest
import torch
from tests.test_utils import assert_expected
from torchtune.modules.loss import ForwardKLLoss, ForwardKLWithChunkedOutputLoss
from torchtune.modules.loss import (
ForwardKLLoss,
ForwardKLWithChunkedOutputLoss,
ReverseKLLoss,
ReverseKLWithChunkedOutputLoss,
SymmetricKLLoss,
SymmetricKLWithChunkedOutputLoss,
)
from torchtune.training.seed import set_seed


@@ -140,3 +147,203 @@ def test_forward_kl_loss_expected(self):
# assert
assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2)
assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2)


class TestReverseKLWithChunkedOutputLoss:
def test_reverse_kl_loss(self):
# Create a sample input and label
ignore_index = -100
batch_size = 3
num_tokens = 50
vocab_size = 50
logits = torch.randn(batch_size, num_tokens, vocab_size, dtype=torch.bfloat16)
teacher_logits = torch.randn(
batch_size, num_tokens, vocab_size, dtype=torch.bfloat16
)
labels = torch.randint(
0, vocab_size, (batch_size, num_tokens), dtype=torch.long
)

        # randomly set roughly a fifth of the label positions to the ignore index
random_indices = torch.randint(0, num_tokens, (batch_size, num_tokens))
labels[random_indices < num_tokens // 5] = ignore_index

# chunked RKL
chunked_rkl_loss = ReverseKLWithChunkedOutputLoss(
num_output_chunks=8, ignore_index=ignore_index
)
logits_chunks = logits.chunk(chunked_rkl_loss.num_output_chunks, dim=1)
teacher_logits_chunks = teacher_logits.chunk(
chunked_rkl_loss.num_output_chunks, dim=1
)
chunked_loss = chunked_rkl_loss(logits_chunks, teacher_logits_chunks, labels)

# vanilla RKL
rkl_loss = ReverseKLLoss(ignore_index=ignore_index)
logits = logits.reshape(-1, logits.size(-1))
teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1))
labels = labels.reshape(-1)
standard_loss = rkl_loss(logits, teacher_logits, labels)

# Assert
assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2)
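The chunked variant exists so that, during training, only one sequence chunk's logits has to be upcast and materialized at a time; the test above checks that chunked and unchunked evaluation agree. A sketch of why they can agree exactly (hypothetical reference code, assuming the loss is a mean over non-ignored tokens): accumulate each chunk's masked sum and token count, then normalize once at the end.

import torch
import torch.nn.functional as F

def chunked_reverse_kl(student_chunks, teacher_chunks, labels, ignore_index=-100):
    # Per-chunk masked sums plus a global token count reproduce the
    # unchunked masked mean, up to floating-point error.
    label_chunks = labels.chunk(len(student_chunks), dim=1)
    total = torch.tensor(0.0)
    num_tokens = torch.tensor(0)
    for s, t, y in zip(student_chunks, teacher_chunks, label_chunks):
        log_ps = F.log_softmax(s.float(), dim=-1)
        log_pt = F.log_softmax(t.float(), dim=-1)
        per_token = (log_ps.exp() * (log_ps - log_pt)).sum(dim=-1)
        mask = y != ignore_index
        total = total + per_token[mask].sum()
        num_tokens = num_tokens + mask.sum()
    return total / num_tokens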

def test_reverse_kl_loss_expected(self):
student_logits = torch.tensor(
[
[
[1.1250, -0.4102, -0.0879, -2.5000],
[0.2676, 0.3535, 0.8711, -1.4688],
[-0.1084, 1.6641, 0.0084, 0.1196],
[0.5000, -0.6406, -0.2236, -1.5938],
],
[
[-1.5312, -1.9219, 0.0000, -0.5039],
[-1.5391, 1.5312, 0.5820, 0.2695],
[-0.3887, 1.2188, 0.0000, 0.6055],
[0.5000, 1.3828, 0.1309, -1.0312],
],
],
dtype=torch.bfloat16,
)
teacher_logits = torch.tensor(
[
[
[-0.0381, -1.2578, -1.2031, 0.0947],
[-0.7852, 0.4492, 1.5547, 0.0972],
[0.8203, 0.0012, 0.7656, 0.3477],
[-1.5781, 0.4297, 0.5977, 0.3926],
],
[
[1.5156, 0.1641, 2.0781, -0.7734],
[-0.5898, 0.4453, -0.7969, 0.6328],
[0.6289, -0.8359, 0.9258, 0.2109],
[0.0006, 0.5195, 3.2344, -1.5781],
],
],
dtype=torch.bfloat16,
)
labels = torch.tensor([[0, 3, 3, 1], [1, 1, 1, 1]])
expected_loss = torch.tensor(0.6775, dtype=torch.float32)

# chunked RKL loss
chunked_rkl_loss = ReverseKLWithChunkedOutputLoss(
num_output_chunks=2, ignore_index=-100
)
student_logits_chunks = student_logits.chunk(
chunked_rkl_loss.num_output_chunks, dim=1
)
teacher_logits_chunks = teacher_logits.chunk(
chunked_rkl_loss.num_output_chunks, dim=1
)
chunked_loss = chunked_rkl_loss(
student_logits_chunks, teacher_logits_chunks, labels
)

# vanilla RKL loss
rkl_loss = ReverseKLLoss(ignore_index=-100)
standard_loss = rkl_loss(student_logits, teacher_logits, labels)

# assert
assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2)
assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2)
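The expected value 0.6775 above is hard-coded; under the assumption that ReverseKLLoss upcasts to float32 and averages over non-ignored tokens, it could be cross-checked with a short reference computation (a sketch, not part of the test suite):

import torch
import torch.nn.functional as F

def reverse_kl_reference(student_logits, teacher_logits, labels, ignore_index=-100):
    # Upcast before softmax, mask out ignored positions, and average over
    # the remaining tokens (assumed normalization).
    log_ps = F.log_softmax(student_logits.float(), dim=-1)
    log_pt = F.log_softmax(teacher_logits.float(), dim=-1)
    per_token = (log_ps.exp() * (log_ps - log_pt)).sum(dim=-1)
    mask = labels != ignore_index
    return per_token[mask].mean()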


class TestSymmetricKLWithChunkedOutputLoss:
def test_symmetric_kl_loss(self):
# Create a sample input and label
ignore_index = -100
batch_size = 3
num_tokens = 50
vocab_size = 50
logits = torch.randn(batch_size, num_tokens, vocab_size, dtype=torch.bfloat16)
teacher_logits = torch.randn(
batch_size, num_tokens, vocab_size, dtype=torch.bfloat16
)
labels = torch.randint(
0, vocab_size, (batch_size, num_tokens), dtype=torch.long
)

        # randomly set roughly a fifth of the label positions to the ignore index
random_indices = torch.randint(0, num_tokens, (batch_size, num_tokens))
labels[random_indices < num_tokens // 5] = ignore_index

# chunked Symmetric KL
chunked_sym_kl_loss = SymmetricKLWithChunkedOutputLoss(
num_output_chunks=8, ignore_index=ignore_index
)
logits_chunks = logits.chunk(chunked_sym_kl_loss.num_output_chunks, dim=1)
teacher_logits_chunks = teacher_logits.chunk(
chunked_sym_kl_loss.num_output_chunks, dim=1
)
chunked_loss = chunked_sym_kl_loss(logits_chunks, teacher_logits_chunks, labels)

# vanilla Symmetric KL
sym_kl_loss = SymmetricKLLoss(ignore_index=ignore_index)
logits = logits.reshape(-1, logits.size(-1))
teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1))
labels = labels.reshape(-1)
standard_loss = sym_kl_loss(logits, teacher_logits, labels)

# Assert
assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2)

def test_symmetric_kl_loss_expected(self):
student_logits = torch.tensor(
[
[
[1.1250, -0.4102, -0.0879, -2.5000],
[0.2676, 0.3535, 0.8711, -1.4688],
[-0.1084, 1.6641, 0.0084, 0.1196],
[0.5000, -0.6406, -0.2236, -1.5938],
],
[
[-1.5312, -1.9219, 0.0000, -0.5039],
[-1.5391, 1.5312, 0.5820, 0.2695],
[-0.3887, 1.2188, 0.0000, 0.6055],
[0.5000, 1.3828, 0.1309, -1.0312],
],
],
dtype=torch.bfloat16,
)
teacher_logits = torch.tensor(
[
[
[-0.0381, -1.2578, -1.2031, 0.0947],
[-0.7852, 0.4492, 1.5547, 0.0972],
[0.8203, 0.0012, 0.7656, 0.3477],
[-1.5781, 0.4297, 0.5977, 0.3926],
],
[
[1.5156, 0.1641, 2.0781, -0.7734],
[-0.5898, 0.4453, -0.7969, 0.6328],
[0.6289, -0.8359, 0.9258, 0.2109],
[0.0006, 0.5195, 3.2344, -1.5781],
],
],
dtype=torch.bfloat16,
)
labels = torch.tensor([[0, 3, 3, 1], [1, 1, 1, 1]])
expected_loss = torch.tensor(1.1992, dtype=torch.float32)

# chunked Symmetric KL loss
chunked_sym_kl_loss = SymmetricKLWithChunkedOutputLoss(
num_output_chunks=2, ignore_index=-100
)
student_logits_chunks = student_logits.chunk(
chunked_sym_kl_loss.num_output_chunks, dim=1
)
teacher_logits_chunks = teacher_logits.chunk(
chunked_sym_kl_loss.num_output_chunks, dim=1
)
chunked_loss = chunked_sym_kl_loss(
student_logits_chunks, teacher_logits_chunks, labels
)

# vanilla Symmetric KL loss
sym_kl_loss = SymmetricKLLoss(ignore_index=-100)
standard_loss = sym_kl_loss(student_logits, teacher_logits, labels)

# assert
assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2)
assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2)
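Under the same assumptions as the reverse-KL reference sketch above, the hard-coded 1.1992 can be cross-checked by mixing the two directions per token (the 50/50 weighting is an assumption about SymmetricKLLoss, not confirmed by this diff):

import torch
import torch.nn.functional as F

def symmetric_kl_reference(student_logits, teacher_logits, labels, ignore_index=-100):
    # Masked mean of the assumed equal-weight two-direction mix.
    log_ps = F.log_softmax(student_logits.float(), dim=-1)
    log_pt = F.log_softmax(teacher_logits.float(), dim=-1)
    fwd = (log_pt.exp() * (log_pt - log_ps)).sum(dim=-1)
    rev = (log_ps.exp() * (log_ps - log_pt)).sum(dim=-1)
    per_token = 0.5 * fwd + 0.5 * rev
    mask = labels != ignore_index
    return per_token[mask].mean()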
13 changes: 12 additions & 1 deletion torchtune/modules/loss/__init__.py
@@ -5,10 +5,21 @@
# LICENSE file in the root directory of this source tree.

from .ce_chunked_output_loss import CEWithChunkedOutputLoss
from .kd_losses import ForwardKLLoss, ForwardKLWithChunkedOutputLoss
from .kd_losses import (
ForwardKLLoss,
ForwardKLWithChunkedOutputLoss,
ReverseKLLoss,
ReverseKLWithChunkedOutputLoss,
SymmetricKLLoss,
SymmetricKLWithChunkedOutputLoss,
)

__all__ = [
"CEWithChunkedOutputLoss",
"ForwardKLLoss",
"ForwardKLWithChunkedOutputLoss",
"ReverseKLLoss",
"ReverseKLWithChunkedOutputLoss",
"SymmetricKLLoss",
"SymmetricKLWithChunkedOutputLoss",
]
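With the new exports in place, the unchunked losses can be used directly, mirroring the call pattern in the tests above (a usage sketch; the shapes follow the vanilla-loss calls in the test file):

import torch
from torchtune.modules.loss import ReverseKLLoss

batch_size, num_tokens, vocab_size = 2, 4, 4
student_logits = torch.randn(batch_size, num_tokens, vocab_size)
teacher_logits = torch.randn(batch_size, num_tokens, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, num_tokens))

loss_fn = ReverseKLLoss(ignore_index=-100)
loss = loss_fn(
    student_logits.reshape(-1, vocab_size),
    teacher_logits.reshape(-1, vocab_size),
    labels.reshape(-1),
)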