⛰️ Reduce peak vram consumption with efficient selective log_softmax (#2799)

* Reduce mem consumption across many trainers with efficient selective log-softmax approach

* rename

* typo fix

* precommit

* Update tests/test_core.py

* relocate

* precommit

* style

* smaller values for test, and run on cpu

* nit doc improvements

* style

* fix test

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
tyler-romero and qgallouedec authored Feb 7, 2025
1 parent 7fdb69a commit 09eefa7
Showing 12 changed files with 110 additions and 50 deletions.
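
The core change replaces the usual `log_softmax` then `gather` pattern with a helper that gathers the selected logits first and subtracts a per-row `logsumexp`, so the full `(batch, seq, vocab)` log-probability tensor is never materialized. A minimal sketch of the idea, with illustrative names rather than the exact TRL code:

```python
import torch


def naive_per_token_logps(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Materializes a full (batch, seq, vocab) log-probability tensor before gathering.
    return torch.gather(logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)


def selective_per_token_logps(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Gather only the label logits, then subtract a per-row logsumexp:
    # log_softmax(x_i) = x_i - logsumexp(x). Looping over sequences keeps the
    # temporary reduction buffers small, which lowers peak memory.
    selected = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    logsumexp = torch.stack([torch.logsumexp(row, dim=-1) for row in logits])
    return selected - logsumexp


if __name__ == "__main__":
    logits = torch.randn(4, 32, 1024)
    labels = torch.randint(0, 1024, (4, 32))
    torch.testing.assert_close(
        naive_per_token_logps(logits, labels),
        selective_per_token_logps(logits, labels),
        rtol=1e-5,
        atol=1e-5,
    )
```

The new `selective_log_softmax` in `trl/trainer/utils.py` below implements this for float32/float64 and falls back to a row-wise `log_softmax` for half-precision dtypes, where the `logsumexp` trick is less numerically stable.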
23 changes: 23 additions & 0 deletions tests/test_utils.py
@@ -17,6 +17,7 @@
import numpy as np
import torch
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available
@@ -32,6 +33,7 @@
generate_model_card,
get_peft_config,
pad,
selective_log_softmax,
)


@@ -506,3 +508,24 @@ def test_batch_accuracy(self):
)
accuracy = compute_token_accuracy(logits, labels)
self.assertAlmostEqual(accuracy, 0.8)


class TestSelectiveLogSoftmax(unittest.TestCase):
@parameterized.expand([(torch.float64,), (torch.float32,), (torch.float16,), (torch.bfloat16,)])
def test_selective_log_softmax(self, dtype):
"""Test selective_log_softmax with logits of different dtypes"""
vocab_size = 1024
batch_size = 4
seq_len = 32

input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))
logits = torch.randn(batch_size, seq_len, vocab_size, dtype=dtype)

expected_output = torch.gather(logits.log_softmax(-1), dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
actual_output = selective_log_softmax(logits, input_ids)

if dtype in [torch.float16, torch.bfloat16]:
# the half-precision fallback performs the same log_softmax -> gather ops as the reference, so outputs match exactly
self.assertTrue(torch.equal(actual_output, expected_output))
else:
torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5)
11 changes: 7 additions & 4 deletions trl/trainer/bco_trainer.py
@@ -65,6 +65,7 @@
log_table_to_comet_experiment,
pad_to_length,
peft_module_casting_to_bf16,
selective_log_softmax,
)


@@ -897,9 +898,11 @@ def _load_optimizer_and_scheduler(self, checkpoint):
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
with self.accelerator.unwrap_model(
self.model
).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
with (
self.accelerator.unwrap_model(self.model).disable_adapter()
if self.is_peft_model and not self.ref_adapter_name
else nullcontext()
):
if self.ref_adapter_name:
self.model.set_adapter(self.ref_adapter_name)
yield
@@ -1062,7 +1065,7 @@ def get_batch_logps(
# dummy token; we'll ignore the losses on these tokens later
labels[labels == label_pad_token_id] = 0

per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps = selective_log_softmax(logits, labels)

if average_log_prob:
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
3 changes: 2 additions & 1 deletion trl/trainer/cpo_trainer.py
@@ -60,6 +60,7 @@
log_table_to_comet_experiment,
pad_to_length,
peft_module_casting_to_bf16,
selective_log_softmax,
)


@@ -711,7 +712,7 @@ def get_batch_logps(
# dummy token; we'll ignore the losses on these tokens later
labels[labels == label_pad_token_id] = 0

per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps = selective_log_softmax(logits, labels)

if average_log_prob:
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
11 changes: 7 additions & 4 deletions trl/trainer/dpo_trainer.py
@@ -69,6 +69,7 @@
pad,
pad_to_length,
peft_module_casting_to_bf16,
selective_log_softmax,
)


@@ -822,9 +823,11 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
with self.accelerator.unwrap_model(
self.model
).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
with (
self.accelerator.unwrap_model(self.model).disable_adapter()
if self.is_peft_model and not self.ref_adapter_name
else nullcontext()
):
if self.ref_adapter_name:
self.model.set_adapter(self.ref_adapter_name)
yield
@@ -1211,7 +1214,7 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to

# Compute the log probabilities of the labels
labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps = selective_log_softmax(logits, labels)
per_token_logps[~loss_mask] = 0
per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1)

9 changes: 2 additions & 7 deletions trl/trainer/grpo_trainer.py
@@ -47,7 +47,7 @@
from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from .callbacks import SyncRefModelCallback
from .grpo_config import GRPOConfig
from .utils import generate_model_card, get_comet_experiment_url, pad
from .utils import generate_model_card, get_comet_experiment_url, pad, selective_log_softmax


if is_peft_available():
@@ -442,12 +442,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
logits = logits[:, -logits_to_keep:]

# Compute the log probabilities for the input tokens.
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) # loop to reduce memory peak
token_log_probs = token_logits - logsumexp_values # log_softmax = logits - log(sum(exp(logits)))
return token_log_probs
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens

def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
device = self.accelerator.device
11 changes: 7 additions & 4 deletions trl/trainer/kto_trainer.py
@@ -63,6 +63,7 @@
log_table_to_comet_experiment,
pad_to_length,
peft_module_casting_to_bf16,
selective_log_softmax,
)


@@ -812,9 +813,11 @@ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
with self.accelerator.unwrap_model(
self.model
).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
with (
self.accelerator.unwrap_model(self.model).disable_adapter()
if self.is_peft_model and not self.ref_adapter_name
else nullcontext()
):
if self.ref_adapter_name:
self.model.set_adapter(self.ref_adapter_name)
yield
@@ -1032,7 +1035,7 @@ def get_batch_logps(
# dummy token; we'll ignore the losses on these tokens later
labels[labels == label_pad_token_id] = 0

per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps = selective_log_softmax(logits, labels)

if average_log_prob:
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
4 changes: 2 additions & 2 deletions trl/trainer/nash_md_trainer.py
@@ -46,6 +46,7 @@
generate_model_card,
get_comet_experiment_url,
get_reward,
selective_log_softmax,
truncate_right,
)

@@ -277,8 +278,7 @@ def _compute_logprobs(self, model, model_data, context_length):
def compute_logprobs_for_data(m, data):
output = m(data["input_ids"], attention_mask=data["attention_mask"])
logits = output.logits[:, context_length - 1 : -1]
logprobs = F.log_softmax(logits, dim=-1)
token_logprobs = torch.gather(logprobs, 2, data["input_ids"][:, context_length:].unsqueeze(-1)).squeeze(-1)
token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
return token_logprobs

# Compute logprobs for model completions under the model
3 changes: 2 additions & 1 deletion trl/trainer/orpo_trainer.py
@@ -64,6 +64,7 @@
log_table_to_comet_experiment,
pad_to_length,
peft_module_casting_to_bf16,
selective_log_softmax,
)


@@ -718,7 +719,7 @@ def get_batch_logps(
# dummy token; we'll ignore the losses on these tokens later
labels = torch.where(labels == label_pad_token_id, 0, labels)

per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps = selective_log_softmax(logits, labels)

if average_log_prob:
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
25 changes: 12 additions & 13 deletions trl/trainer/ppo_trainer.py
@@ -25,7 +25,6 @@
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import broadcast, gather_object
from datasets import Dataset
@@ -65,6 +64,7 @@
peft_module_casting_to_bf16,
prepare_deepspeed,
print_rich_table,
selective_log_softmax,
truncate_response,
)

@@ -310,9 +310,11 @@ def get_eval_dataloader(self) -> DataLoader:
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
with self.accelerator.unwrap_model(
self.model.policy
).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
with (
self.accelerator.unwrap_model(self.model.policy).disable_adapter()
if self.is_peft_model and not self.ref_adapter_name
else nullcontext()
):
if self.ref_adapter_name:
self.model.policy.set_adapter(self.ref_adapter_name)
yield
@@ -427,9 +429,8 @@ def repeat_generator():
query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
response = query_response[:, context_length:]
logits = logitss[i : i + args.local_rollout_forward_batch_size]
all_logprob = F.log_softmax(logits, dim=-1)
logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del logits, all_logprob
logprob = selective_log_softmax(logits, response)
del logits
torch.cuda.empty_cache()

if ref_policy is None:
@@ -439,9 +440,8 @@
ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
ref_logits = ref_output.logits[:, context_length - 1 : -1]
ref_logits /= args.temperature + 1e-7
ref_all_logprob = F.log_softmax(ref_logits, dim=-1)
ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del ref_output, ref_logits, ref_all_logprob
ref_logprob = selective_log_softmax(ref_logits, response)
del ref_output, ref_logits
torch.cuda.empty_cache()

# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
@@ -547,8 +547,7 @@ def repeat_generator():
output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
logits = output.logits[:, context_length - 1 : -1]
logits /= args.temperature + 1e-7
new_all_logprobs = F.log_softmax(logits, dim=-1)
new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1)
new_logprobs = selective_log_softmax(logits, mb_responses)
new_logprobs = torch.masked_fill(
new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
)
@@ -599,7 +598,7 @@ def repeat_generator():
# del everything and empty cache
# fmt: off
del (
output, vpred_temp, logits, new_all_logprobs, new_logprobs, vpred, vpredclipped,
output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
20 changes: 8 additions & 12 deletions trl/trainer/rloo_trainer.py
@@ -24,7 +24,6 @@
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import broadcast, gather_object
from datasets import Dataset
@@ -56,6 +55,7 @@
get_reward,
prepare_deepspeed,
print_rich_table,
selective_log_softmax,
truncate_response,
)
from .rloo_config import RLOOConfig
@@ -330,17 +330,15 @@ def repeat_generator():
query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
response = query_response[:, context_length:]
logits = logitss[i : i + args.local_rollout_forward_batch_size]
all_logprob = F.log_softmax(logits, dim=-1)
logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del logits, all_logprob
logprob = selective_log_softmax(logits, response)
del logits
torch.cuda.empty_cache()

ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
ref_logits = ref_output.logits[:, context_length - 1 : -1]
ref_logits /= args.temperature + 1e-7
ref_all_logprob = F.log_softmax(ref_logits, dim=-1)
ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del ref_output, ref_logits, ref_all_logprob
ref_logprob = selective_log_softmax(ref_logits, response)
del ref_output, ref_logits
torch.cuda.empty_cache()

# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
@@ -467,8 +465,7 @@ def repeat_generator():
logits /= args.temperature + 1e-7

# Compute new logprobs
new_all_logprobs = F.log_softmax(logits, dim=-1)
new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1)
new_logprobs = selective_log_softmax(logits, mb_responses)
new_logprobs = torch.masked_fill(
new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
)
@@ -512,9 +509,8 @@ def repeat_generator():
# del everything and empty cache
# fmt: off
del (
output, logits, new_all_logprobs, new_logprobs,
logprobs_diff, ratio, pg_losses, pg_losses2,
pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl,
output, logits, new_logprobs, logprobs_diff, ratio, pg_losses,
pg_losses2, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl,
mb_advantage, mb_responses, mb_query_responses, mb_logprobs,
)
# fmt: on
36 changes: 36 additions & 0 deletions trl/trainer/utils.py
@@ -26,6 +26,7 @@
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.utils.data
from accelerate import Accelerator, PartialState
from accelerate.state import AcceleratorState
@@ -1668,3 +1669,38 @@ def compute_token_accuracy(logits: torch.Tensor, labels: torch.Tensor, ignore_in
accuracy = correct_tokens.item() / total_tokens.item() if total_tokens > 0 else 0.0

return accuracy


def selective_log_softmax(logits, index):
"""
A memory-efficient implementation of the common `log_softmax -> gather` operation.
This function is equivalent to the following naive implementation:
```python
logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
```
Args:
logits (`torch.Tensor`):
Logits tensor of shape `(..., num_classes)`.
index (`torch.Tensor`):
Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output.
Returns:
`torch.Tensor`:
Gathered log probabilities with the same shape as `index`.
"""
if logits.dtype in [torch.float32, torch.float64]:
selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
# loop to reduce peak mem consumption
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
else:
# the logsumexp approach is unstable in half precision (float16/bfloat16), so fall back to a slightly less efficient row-wise approach
per_token_logps = []
for row_logits, row_labels in zip(logits, index): # loop to reduce peak mem consumption
row_logps = F.log_softmax(row_logits, dim=-1)
row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
per_token_logps.append(row_per_token_logps)
per_token_logps = torch.stack(per_token_logps)
return per_token_logps
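
For reference, a hedged usage sketch of the new helper as a trainer might call it; the checkpoint name and the next-token shift are illustrative assumptions, not something this commit prescribes, and it assumes a TRL install that includes this change:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl.trainer.utils import selective_log_softmax

# Illustrative checkpoint; any causal LM works the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Selective log-softmax keeps peak memory down.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # (batch, seq, vocab)

# Log-probability of each realized next token, computed without materializing
# log_softmax over the whole vocabulary dimension.
per_token_logps = selective_log_softmax(logits[:, :-1], inputs["input_ids"][:, 1:])
print(per_token_logps.shape)  # (batch, seq - 1)
```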