From cdb78413f529028841586a27c7ec32267714bcc5 Mon Sep 17 00:00:00 2001
From: Felipe Mello
Date: Tue, 19 Nov 2024 08:43:56 -0800
Subject: [PATCH 1/4] transformer x graph break

---
 torchtune/modules/transformer.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py
index 5b1fb88739..9d64a1228e 100644
--- a/torchtune/modules/transformer.py
+++ b/torchtune/modules/transformer.py
@@ -633,7 +633,9 @@ def forward(
         for i, layer in enumerate(self.layers):
             if i in self.output_hidden_states:
                 hidden.append(h)
+
             # shape: [b, s, d]
+            torch._dynamo.mark_dynamic(h, 1)  # avoid graph breaks
             h = layer(
                 h,
                 mask=mask,

From 68f7951a5fa69463c8ec4f38c31a0eebe8c2af93 Mon Sep 17 00:00:00 2001
From: Felipe Mello
Date: Tue, 19 Nov 2024 08:44:08 -0800
Subject: [PATCH 2/4] log graph break when packed

---
 torchtune/modules/attention_utils.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/torchtune/modules/attention_utils.py b/torchtune/modules/attention_utils.py
index 8afd4eba71..de0059e52b 100644
--- a/torchtune/modules/attention_utils.py
+++ b/torchtune/modules/attention_utils.py
@@ -191,11 +191,12 @@ def _attention_call(
         # This will use flash attention under the hood with support for custom masks.
         # Currently, it is used when sample packing is enabled (see torchtune.datasets.PackedDataset)
         if isinstance(mask, BlockMask):
-            log_once(
-                _log,
-                "Using flex attention for attention computation since a BlockMask was passed in.",
-                level=logging.DEBUG,
-            )
+            if not torch.compiler.is_compiling():  # avoid graph break
+                log_once(
+                    _log,
+                    "Using flex attention for attention computation since a BlockMask was passed in.",
+                    level=logging.DEBUG,
+                )
             if dropout_p > 0.0:
                 raise ValueError(
                     "Flex attention does not support dropout. Please set dropout to 0.0."
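Both patches above are about avoiding torch.compile graph breaks: patch 1 marks the sequence dimension of the activations as dynamic so Dynamo does not recompile for every new sequence length, and patch 2 skips a Python-side log call while Dynamo is tracing, since that call would otherwise break the graph. A minimal, self-contained sketch of those two PyTorch APIs follows; the toy module and shapes are illustrative, not torchtune code.

    import torch

    class ToyLayer(torch.nn.Module):
        def __init__(self, dim: int = 64):
            super().__init__()
            self.proj = torch.nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if not torch.compiler.is_compiling():
                # Python-side printing/logging causes a graph break while
                # torch.compile is tracing, so it only runs in eager mode.
                print(f"eager call, seq_len={x.shape[1]}")
            return self.proj(x)

    layer = torch.compile(ToyLayer())
    for seq_len in (8, 16, 32):
        x = torch.randn(2, seq_len, 64)
        # Mark dim 1 (sequence length) as dynamic so the compiled graph is not
        # specialized to one length and recompiled for the next one.
        torch._dynamo.mark_dynamic(x, 1)
        layer(x)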
From 34b9ceb846352bc5660f3f8e2d849917d0bdcb16 Mon Sep 17 00:00:00 2001
From: Felipe Mello
Date: Tue, 19 Nov 2024 11:23:16 -0800
Subject: [PATCH 3/4] fix

---
 .../llama3_2_vision/_component_builders.py    | 21 +------------------
 .../models/llama3_2_vision/_model_builders.py |  2 --
 torchtune/modules/peft/dora.py                |  2 +-
 torchtune/modules/peft/lora.py                |  2 +-
 4 files changed, 3 insertions(+), 24 deletions(-)

diff --git a/torchtune/models/llama3_2_vision/_component_builders.py b/torchtune/models/llama3_2_vision/_component_builders.py
index 3de323d368..4f3e6403e0 100644
--- a/torchtune/models/llama3_2_vision/_component_builders.py
+++ b/torchtune/models/llama3_2_vision/_component_builders.py
@@ -338,7 +338,6 @@ def lora_llama3_2_vision_encoder(
     fusion_lora: bool,
     lora_attn_modules: List[LORA_ATTN_MODULES],
     apply_lora_to_mlp: bool = False,
-    apply_lora_to_output: bool = False,
     *,
     # clip encoder parameters
     patch_size: int,
@@ -377,8 +376,6 @@ def lora_llama3_2_vision_encoder(
             ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
         apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
             Default: False
-        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
-            Default: False
         patch_size (int): The size of each patch. Used to divide the tiles into patches.
             E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches
             with shape (40, 40) each.
@@ -412,7 +409,6 @@ def lora_llama3_2_vision_encoder(
     lora_options = {
         "lora_modules": lora_attn_modules,
         "apply_lora_to_mlp": apply_lora_to_mlp,
-        "apply_lora_to_output": apply_lora_to_output,
         "lora_rank": lora_rank,
         "lora_alpha": lora_alpha,
         "lora_dropout": lora_dropout,
@@ -679,7 +675,6 @@ def lora_llama3_2_vision_projection_head(
     num_hidden_inputs: int,
     # LoRA args
     apply_lora_to_mlp: bool,
-    apply_lora_to_output: bool,
     lora_rank: int,
     lora_alpha: float,
     lora_dropout: float = 0.0,
@@ -701,8 +696,6 @@ def lora_llama3_2_vision_projection_head(
         num_hidden_inputs (int): number of hidden inputs to the projection head.
         apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
             Default: False
-        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
-            Default: False
         lora_rank (int): rank of each low-rank approximation
         lora_alpha (float): scaling factor for the low-rank approximation
         lora_dropout (float): LoRA dropout probability. Default: 0.0
@@ -773,19 +766,7 @@ def lora_llama3_2_vision_projection_head(
     # cross encoding
     # TODO: quantize_base is not applied to final output_proj currently.
     proj_in = clip_embed_dim * (num_hidden_inputs + 1)
-    adapter_cls = DoRALinear if use_dora else LoRALinear
-    output_proj = (
-        adapter_cls(
-            proj_in,
-            decoder_embed_dim,
-            rank=lora_rank,
-            alpha=lora_alpha,
-            dropout=lora_dropout,
-            use_bias=True,
-        )
-        if apply_lora_to_output
-        else nn.Linear(proj_in, decoder_embed_dim)
-    )
+    output_proj = nn.Linear(proj_in, decoder_embed_dim)
     return Llama3VisionProjectionHead(
         layers=layers,
         output=output_proj,
diff --git a/torchtune/models/llama3_2_vision/_model_builders.py b/torchtune/models/llama3_2_vision/_model_builders.py
index d13ff2dcc4..91e54781af 100644
--- a/torchtune/models/llama3_2_vision/_model_builders.py
+++ b/torchtune/models/llama3_2_vision/_model_builders.py
@@ -172,7 +172,6 @@ def lora_llama3_2_vision_11b(
         fusion_lora=fusion_type == LoRATrainable.LORA,
         lora_attn_modules=lora_attn_modules,
         apply_lora_to_mlp=apply_lora_to_mlp,
-        apply_lora_to_output=apply_lora_to_output,
         patch_size=14,
         num_heads=16,
         clip_embed_dim=1280,
@@ -330,7 +329,6 @@ def lora_llama3_2_vision_90b(
         fusion_lora=fusion_type == LoRATrainable.LORA,
         lora_attn_modules=lora_attn_modules,
         apply_lora_to_mlp=apply_lora_to_mlp,
-        apply_lora_to_output=apply_lora_to_output,
         patch_size=14,
         num_heads=16,
         clip_embed_dim=1280,
diff --git a/torchtune/modules/peft/dora.py b/torchtune/modules/peft/dora.py
index bc1e5eeb03..6f097da6d0 100644
--- a/torchtune/modules/peft/dora.py
+++ b/torchtune/modules/peft/dora.py
@@ -65,7 +65,7 @@ def __init__(
         self.use_bias = use_bias
         self._quantize_base = quantize_base

-        if not self._quantize_base and quantization_kwargs:
+        if not self._quantize_base and any([v for v in quantization_kwargs.values()]):
             raise ValueError(
                 f"``quantize_base`` is False, but received the following quantization arguments: {quantization_kwargs}"
             )
diff --git a/torchtune/modules/peft/lora.py b/torchtune/modules/peft/lora.py
index 138dd0c5ee..e03d854f1f 100644
--- a/torchtune/modules/peft/lora.py
+++ b/torchtune/modules/peft/lora.py
@@ -65,7 +65,7 @@ def __init__(
         self.use_bias = use_bias
         self._quantize_base = quantize_base

-        if not self._quantize_base and quantization_kwargs:
+        if not self._quantize_base and any([v for v in quantization_kwargs.values()]):
             raise ValueError(
                 f"``quantize_base`` is False, but received the following quantization arguments: {quantization_kwargs}"
             )
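Besides dropping apply_lora_to_output from the vision encoder builders, patch 3 relaxes the guard in DoRALinear/LoRALinear: instead of rejecting any non-empty quantization_kwargs when quantize_base is False, it only raises when at least one value is actually set. A small illustration of the difference; the key names below are made up, the real kwargs are whatever the builders forward.

    quantize_base = False
    quantization_kwargs = {"block_size": None, "scale_dtype": None}

    # old guard: any non-empty dict triggers the ValueError,
    # even when every value is an unset default
    old_raises = not quantize_base and bool(quantization_kwargs)  # True

    # new guard: only raise when some quantization kwarg is truthy
    new_raises = not quantize_base and any(
        [v for v in quantization_kwargs.values()]
    )  # False

    print(old_raises, new_raises)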
From b41114ad3631c8409c8c2409755e70acbf66e016 Mon Sep 17 00:00:00 2001
From: Felipe Mello
Date: Tue, 19 Nov 2024 14:31:28 -0800
Subject: [PATCH 4/4] fix graph breaks

---
 recipes/lora_finetune_single_device.py           |  8 +++++++-
 torchtune/modules/loss/ce_chunked_output_loss.py |  6 ++++++
 torchtune/modules/transformer.py                 | 11 ++++++++++-
 3 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py
index fcdb3e4ea5..467ad53fc1 100644
--- a/recipes/lora_finetune_single_device.py
+++ b/recipes/lora_finetune_single_device.py
@@ -664,6 +664,7 @@ def train(self) -> None:
         )

         # Initialize tokens count and running loss (for grad accumulation)
+        start = time.perf_counter()
         t0 = time.perf_counter()
         running_loss = 0
         num_tokens = 0
@@ -730,6 +731,7 @@ def train(self) -> None:
                     # Log per-step metrics
                     if self.global_step % self._log_every_n_steps == 0:
                         time_per_step = time.perf_counter() - t0
+                        print(time_per_step)
                         log_dict = {
                             "loss": loss_to_log,
                             "lr": self._optimizer.param_groups[0]["lr"],
@@ -773,13 +775,17 @@ def train(self) -> None:
             self.epochs_run += 1
             start_save_checkpoint = time.perf_counter()
             log.info("Starting checkpoint save...")
-            self.save_checkpoint(epoch=curr_epoch)
+            # self.save_checkpoint(epoch=curr_epoch)
             log.info(
                 "Checkpoint saved in {:.2f} seconds.".format(
                     time.perf_counter() - start_save_checkpoint
                 )
             )

+        end = time.perf_counter()
+        time_total = end - start
+        print(f"{time_total=}")
+
     def cleanup(self) -> None:
         self._metric_logger.close()
diff --git a/torchtune/modules/loss/ce_chunked_output_loss.py b/torchtune/modules/loss/ce_chunked_output_loss.py
index 17a5eced36..ff1525758a 100644
--- a/torchtune/modules/loss/ce_chunked_output_loss.py
+++ b/torchtune/modules/loss/ce_chunked_output_loss.py
@@ -78,6 +78,12 @@ def forward(self, logits: List[torch.Tensor], labels: torch.Tensor) -> torch.Ten
         # compute one chunk at a time
         total_loss = 0.0
         for logits_chunk, labels_chunk in zip(logits, labels):
+
+            # avoid graph breaks when seq_len is not constant in the batch
+            torch._dynamo.mark_dynamic(logits_chunk, 0)
+            torch._dynamo.mark_dynamic(labels_chunk, 0)
+
+            # CE
             total_loss += self.compute_cross_entropy(logits_chunk, labels_chunk)

         return total_loss / total_elements
diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py
index 9d64a1228e..23aa79e050 100644
--- a/torchtune/modules/transformer.py
+++ b/torchtune/modules/transformer.py
@@ -628,14 +628,23 @@ def forward(

         # shape: [b, s, d]
         h = self.tok_embeddings(tokens)
+        h.requires_grad = True  # avoid graph breaks when using LoRA

         hidden = []
         for i, layer in enumerate(self.layers):
             if i in self.output_hidden_states:
                 hidden.append(h)

+            # avoid graph breaks when seq_len is not constant in the batch
+            torch._dynamo.mark_dynamic(h, 1)
+            if mask is not None:
+                torch._dynamo.mark_dynamic(mask, 1)
+            if encoder_mask is not None:
+                torch._dynamo.mark_dynamic(encoder_mask, 1)
+            if input_pos is not None:
+                torch._dynamo.mark_dynamic(input_pos, 1)
+
             # shape: [b, s, d]
-            torch._dynamo.mark_dynamic(h, 1)  # avoid graph breaks
             h = layer(
                 h,
                 mask=mask,
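The loss-side change in patch 4 mirrors the decoder change: each logits/labels chunk has a length that scales with the batch's sequence length, so its leading dimension is marked dynamic before computing cross-entropy. A rough sketch of that pattern under assumed shapes follows; it is not the actual implementation in ce_chunked_output_loss.py, and the chunk sizes and ignore index are illustrative.

    import torch
    import torch.nn.functional as F

    def chunked_ce(logits_chunks, labels_chunks, ignore_index=-100):
        total_elements = sum((lbl != ignore_index).sum() for lbl in labels_chunks)
        total_loss = 0.0
        for logits_chunk, labels_chunk in zip(logits_chunks, labels_chunks):
            # The chunk length varies from batch to batch, so mark dim 0 as
            # dynamic to keep torch.compile from recompiling per length.
            torch._dynamo.mark_dynamic(logits_chunk, 0)
            torch._dynamo.mark_dynamic(labels_chunk, 0)
            total_loss += F.cross_entropy(
                logits_chunk.float(),
                labels_chunk,
                reduction="sum",
                ignore_index=ignore_index,
            )
        return total_loss / total_elements

    # toy usage: two chunks of a [seq_len, vocab_size] logits tensor
    vocab_size, chunk_len = 128, 16
    logits = [torch.randn(chunk_len, vocab_size) for _ in range(2)]
    labels = [torch.randint(0, vocab_size, (chunk_len,)) for _ in range(2)]
    print(chunked_ce(logits, labels))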