From d64309525804e3a643a286764ed43256bca82124 Mon Sep 17 00:00:00 2001 From: kylesayrs Date: Wed, 2 Oct 2024 19:29:45 +0000 Subject: [PATCH 001/285] make max the default --- src/llmcompressor/transformers/compression/helpers.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/transformers/compression/helpers.py b/src/llmcompressor/transformers/compression/helpers.py index 845f04c4a..e5f7b6588 100644 --- a/src/llmcompressor/transformers/compression/helpers.py +++ b/src/llmcompressor/transformers/compression/helpers.py @@ -205,7 +205,7 @@ def custom_offload_device_map( def calculate_offload_device_map( model_stub: str, reserve_for_hessians=False, - num_gpus: int = 1, + num_gpus: Optional[int] = None, torch_dtype: torch.dtype = torch.float16, **model_kwargs, ) -> Dict[Union[int, str], Union[int, str]]: @@ -215,14 +215,16 @@ def calculate_offload_device_map( :param model_stub: local path or HF stub to calculate mapping for :param reserve_for_hessians: whether to reserve memory for GPTQ - :param num_gpus: number of gpus to utilize + :param num_gpus: number of gpus to utilize, defaults to max available :param model_kwargs: keyword arguments to pass to model initializer :return: memory mapping for layers of model_stub to be passed to from_pretrained() """ max_cpu_memory = psutil.virtual_memory().available max_gpu_memory = torch.cuda.mem_get_info(0)[0] available_gpus = torch.cuda.device_count() - if available_gpus < num_gpus: + if num_gpus is None: + num_gpus = available_gpus + elif num_gpus >= available_gpus: raise ValueError( f"Requested {num_gpus} GPUs but only {available_gpus} are available." ) From 536b6864c645cec576ad37cbdaec4416cc7bd1e5 Mon Sep 17 00:00:00 2001 From: kylesayrs Date: Wed, 2 Oct 2024 19:59:57 +0000 Subject: [PATCH 002/285] set default for num_gpus, add gpu_ids argument --- .../transformers/compression/helpers.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/src/llmcompressor/transformers/compression/helpers.py b/src/llmcompressor/transformers/compression/helpers.py index e5f7b6588..3c6d49c7a 100644 --- a/src/llmcompressor/transformers/compression/helpers.py +++ b/src/llmcompressor/transformers/compression/helpers.py @@ -206,6 +206,7 @@ def calculate_offload_device_map( model_stub: str, reserve_for_hessians=False, num_gpus: Optional[int] = None, + gpu_ids: Optional[List[int]] = None, torch_dtype: torch.dtype = torch.float16, **model_kwargs, ) -> Dict[Union[int, str], Union[int, str]]: @@ -214,21 +215,31 @@ def calculate_offload_device_map( into account extra memory required for quantization and (optionally) GPTQ hessians :param model_stub: local path or HF stub to calculate mapping for - :param reserve_for_hessians: whether to reserve memory for GPTQ + :param reserve_for_hessians: whether to reserve memory for GPTQ/OBCQ :param num_gpus: number of gpus to utilize, defaults to max available + :param gpu_ids: list of gpu device ids to utilize, overrides num_gpus if provided + :param torch_dtype: datatype in which model weights are to be loaded with :param model_kwargs: keyword arguments to pass to model initializer :return: memory mapping for layers of model_stub to be passed to from_pretrained() """ max_cpu_memory = psutil.virtual_memory().available - max_gpu_memory = torch.cuda.mem_get_info(0)[0] available_gpus = torch.cuda.device_count() - if num_gpus is None: - num_gpus = available_gpus - elif num_gpus >= available_gpus: + if gpu_ids is None: + if num_gpus is None: + num_gpus = available_gpus + gpu_ids = range(num_gpus) + else: + num_gpus = len(gpu_ids) + + if num_gpus > available_gpus: raise ValueError( f"Requested {num_gpus} GPUs but only {available_gpus} are available." ) - max_gpu_memory = [max_gpu_memory] * num_gpus + + max_gpu_memory = { + device_id: torch.cuda.mem_get_info(device_id)[0] + for device_id in gpu_ids + } device_map = {} with init_empty_weights(): @@ -243,7 +254,7 @@ def calculate_offload_device_map( memory_limits = { idx: (max_memory - reserved_memory) - for idx, max_memory in enumerate(max_gpu_memory) + for idx, max_memory in max_gpu_memory.items() } memory_limits["cpu"] = max_cpu_memory From 21b6a7cc631eb8f1a1944c364713bf7d0be011d4 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 2 Oct 2024 16:09:42 -0400 Subject: [PATCH 003/285] update documentation --- src/llmcompressor/transformers/compression/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/transformers/compression/helpers.py b/src/llmcompressor/transformers/compression/helpers.py index 3c6d49c7a..7d4428fb1 100644 --- a/src/llmcompressor/transformers/compression/helpers.py +++ b/src/llmcompressor/transformers/compression/helpers.py @@ -113,7 +113,7 @@ def infer_sparsity_structure_from_model(model: torch.nn.Module) -> Optional[str] def hessian_memory_requirements(model: torch.nn.Module) -> int: """ Determines the number of bytes needed to store Hessian data for a single - transformer layer in model. This is used for reserving memory for GPTQ + transformer layer in model. This is used for reserving memory for GPTQ/OBCQ quantization :param model: model to calculate requirements for From fc87f7913ffd9762f391df03202a7e4f5b9bd3a2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 14 Oct 2024 19:32:55 -0400 Subject: [PATCH 004/285] remove unnecessary lines --- src/llmcompressor/transformers/compression/helpers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/llmcompressor/transformers/compression/helpers.py b/src/llmcompressor/transformers/compression/helpers.py index 7d4428fb1..b466b4640 100644 --- a/src/llmcompressor/transformers/compression/helpers.py +++ b/src/llmcompressor/transformers/compression/helpers.py @@ -241,7 +241,6 @@ def calculate_offload_device_map( for device_id in gpu_ids } - device_map = {} with init_empty_weights(): dummy_model = AutoModelForCausalLM.from_pretrained( model_stub, torch_dtype=torch_dtype, **model_kwargs @@ -263,7 +262,6 @@ def calculate_offload_device_map( max_memory=memory_limits, no_split_module_classes=dummy_model._no_split_modules, ) - del dummy_model return device_map From 7f641445e29861a2d3e68d97cc1ce280fa749588 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 14 Oct 2024 19:34:09 -0400 Subject: [PATCH 005/285] Remove unnecessary lines --- src/llmcompressor/transformers/compression/helpers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/llmcompressor/transformers/compression/helpers.py b/src/llmcompressor/transformers/compression/helpers.py index b466b4640..a505dc0af 100644 --- a/src/llmcompressor/transformers/compression/helpers.py +++ b/src/llmcompressor/transformers/compression/helpers.py @@ -189,7 +189,6 @@ def custom_offload_device_map( memory_limits = {device: max_memory_per_gpu for device in range(num_gpus)} memory_limits["cpu"] = max_cpu_memory - device_map = {} with init_empty_weights(): dummy_model = AutoModelForCausalLM.from_pretrained(model_stub, **model_kwargs) device_map = infer_auto_device_map( @@ -197,7 +196,6 @@ def custom_offload_device_map( max_memory=memory_limits, no_split_module_classes=dummy_model._no_split_modules, ) - del dummy_model return device_map From 98b284b9b1213e9749df7d5b95323dd92cd2f98a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 04:03:55 +0000 Subject: [PATCH 006/285] WIP --- examples/quantization_w4a16/llama3_example.py | 1 + .../modifiers/quantization/gptq/base.py | 235 ++++++--------- .../quantization/gptq/utils/compress.py | 278 ++++++++++++++++++ src/llmcompressor/utils/helpers.py | 45 +++ 4 files changed, 419 insertions(+), 140 deletions(-) create mode 100644 src/llmcompressor/modifiers/quantization/gptq/utils/compress.py diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index 939991ab6..d587a6199 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -6,6 +6,7 @@ # Select model and load it. MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" +#MODEL_ID = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" model = SparseAutoModelForCausalLM.from_pretrained( MODEL_ID, diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index b472e289e..fd1c42116 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch +from functools import partial from compressed_tensors.quantization import ( QuantizationScheme, disable_quantization, @@ -21,6 +22,7 @@ from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.utils.fsdp.context import fix_fsdp_module_name +from llmcompressor.utils.helpers import DisableKVCache, DisableQuantization, getattr_chain from llmcompressor.utils.pytorch.module import ( get_layers, get_no_split_params, @@ -109,11 +111,6 @@ class GPTQModifier(Modifier): num_calibration_steps: Optional[int] = None scheme: Optional[Union[str, Dict[str, Any]]] = None - model: Optional[Any] = None - layer_compressors_: Optional[List[Any]] = None - compressible_layers_: Optional[List] = None - quantization_modifier_: Any = None - @field_validator("sequential_update", mode="before") def validate_sequential_update(cls, value: bool) -> bool: if not value: @@ -124,6 +121,13 @@ def validate_sequential_update(cls, value: bool) -> bool: ) return value + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.current_layer_index = 0 + self.num_layers = 0 + self.quantization_modifier_ = None def on_initialize_structure(self, state: State, **kwargs): """ @@ -191,20 +195,29 @@ def on_initialize(self, state: "State", **kwargs) -> bool: if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") - modifiable_model = state.model - calibration_dataloader = state.data.calib - + # find layers (used for printing even if true_sequential=True) + # if no targets are provided, default to the modules that shouldn't be + # split by FSDP. For Transformers models this is equivalent to the + # decoder layers (ie LlamaDecoderLayer) if self.sequential_targets is None: - # if no targets are provided, default to the modules that shouldn't be - # split by FSDP. For Transformers models this is equivalent to the - # decoder layers (ie LlamaDecoderLayer) - self.sequential_targets = get_no_split_params(modifiable_model) + self.sequential_targets = get_no_split_params(state.model) + layers = get_layers(self.sequential_targets, state.model) + self.num_layers = len(layers) + + # add hooks to targets and layers + self.register_hooks(state.model, layers) + + # apply calibration and trigger hooks (hooks are self removing) + self.calibration_forward(state.model, state.data.calib) - self.initialize_compression(modifiable_model, calibration_dataloader) - self.apply_compression(calibration_dataloader) + # freeze quantization state.model.apply(freeze_module_quantization) return True + + def on_end(self): + self.register_hooks(state.model, layers) + self.dummy_forward() ??? def on_finalize(self, state: "State", **kwargs) -> bool: """ @@ -216,121 +229,80 @@ def on_finalize(self, state: "State", **kwargs) -> bool: self.quantization_modifier_.finalize(state, **kwargs) return True - - def compressible_layers(self) -> Dict: - """ - Retrieves the modules corresponding to a list of - compressible layer names - - :precondition: self.model is set and is a torch.nn.Module - :return: dictionary of modules to compress - """ - if not isinstance(self.model, Module): - raise ValueError( - "`self.model` must be a torch.nn.Module to use " - f"the {self.__class__.__qualname__} modifier but got " - f"{type(self.model)} instead" + + def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Module]): + layers = layers.values() + + for name, module in model.named_modules(): + quant_args = getattr_chain(module, "quantization_scheme.weights", None) + if quant_args is not None: + module._gptq_pre_hook = module.register_forward_pre_hook( + partial(self.target_pre_forward, name, quant_args)) + module._gptq_post_hook = module.register_forward_hook( + partial(self.target_post_forward, name, quant_args)) + + if module in layers.values(): + module._gptq_pre_hook = module.register_forward_pre_hook( + partial(self.layer_pre_forward, name)) + module._gptq_post_hook = module.register_forward_hook( + partial(self.layer_post_forward, name)) + + def calibration_forward(self, model: torch.nn.Module, data: torch.utils.data.Dataloader): + all_data = torch.cat([batch for batch in data], dim=0) + with DisableKVCache(model), DisableQuantization(model): + model(all_data) + + def target_pre_forward(self, name: str, quant_args: QuantizationScheme, module: torch.nn.Module, args, kwargs): + if self.true_sequential: + # compress first so output is from quantized weights + logger.info(f"Compressing {name}...") + gptq_compress( + module, + args, + kwargs, + quant_args, + block_size=self.block_size, + percdamp=self.dampening_frac, ) + + def target_post_forward(self, name: str, quant_args: QuantizationScheme, module: torch.nn.Module, args, kwargs, output): + if not self.true_sequential: + # compress after so output is from unquantized weights + logger.info(f"Compressing {name}...") + gptq_compress( + module, + args, + kwargs, + quant_args, + block_size=self.block_size, + percdamp=self.dampening_frac, + ) + + def layer_pre_forward(self, name: str, module: torch.nn.Module, args, kwargs): + logger.info(f"\n===== Compressing layer {self.layer_index}/{self.num_layers} =====") + + def layer_post_forward(self, name: str, module: torch.nn.Module, args, kwargs, output): + self.remove_hooks(module) - return get_layers(self.sequential_targets, self.model) - - def initialize_compression( - self, - model: Module, - dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None, - ): - """ - Setup for GPTQ, initializes the model - and other parameters, also initilializes the - compressible layers of model, and sets the device - - :param model: model to initialize for compression - :param dataloader: calibration data for GPTQ - """ - self.model = model - self.compressible_layers_ = self.compressible_layers() - self.layer_compressors_ = [] - - for idx, (name, layer) in enumerate(self.compressible_layers_.items()): - name = fix_fsdp_module_name(name) - logger.info(f"Preparing {name} for compression") - args = self._pruning_arguments() - comp_cls = self._compression_class() - compressor = LayerCompressor(comp_cls, self.model, layer, idx, name, args) - - # if running sequentially, allocate all hessians now - if not self.sequential_update: - compressor.pre_compress() - - self.layer_compressors_.append(compressor) - - if self.sequential_update: - first_layer_compressor = self.layer_compressors_[0] - first_layer_compressor.set_early_stop() - - @torch.no_grad() - def apply_compression( - self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None - ) -> Dict: - """ - Run GPTQ on the loaded model, using dataloader as calibration data - - :param dataloader: calibration data for GPTQ - """ - class_name = self.__class__.__name__.replace("PyTorch", "") - logger.info( - f"Running {class_name} calibration with " f"{len(dataloader)} samples..." - ) - - # quantization scales and zp are already initialized but we do not - # want to calibrate wrt to these - self.model.apply(disable_quantization) - - forward_pass_use_cache = self.model.config.use_cache - self.model.config.use_cache = False - - # in non-sequential mode we run calibration through the full model - # in sequential mode we run calibration up to the first transformer target - intermediates = run_calibration_forward( - self.model, dataloader, mask_padding=True - ) - self.layer_compressors_[0].clear_early_stop() - - # empty cache if not using sequential update - if not self.sequential_update: - del intermediates - gc.collect() - torch.cuda.empty_cache() - - num_layers = len(self.compressible_layers_) - for idx, layer_compressor in enumerate(self.layer_compressors_): - logger.info(f"\n===== Compressing layer {idx+1}/{num_layers} " " =====") - - if self.sequential_update: - # in sequential mode we run the forward pass for each transformer layer - # one at a time, caching the intermediate outputs between layers - logger.info(f"Calibrating {layer_compressor.name}...") - layer_compressor.pre_compress() - unquantized_outputs = layer_compressor.calibrate_layer(intermediates) - - layer_compressor.compress() - layer_compressor.post_compress() - layer_compressor.revert_layer_wrappers() + if not self.true_sequential: + # rerun with (now) quantized weights + output = module(*args, **kwargs) - if self.sequential_update: - quantized_outputs = layer_compressor.calibrate_layer(intermediates) - error = get_output_error(unquantized_outputs, quantized_outputs) - logger.info(f"Mean output error from quantization: {error:.3f}") - intermediates = quantized_outputs - del unquantized_outputs + self.layer_index += 1 + return output - gc.collect() - torch.cuda.empty_cache() + def remove_hooks(self, module: torch.nn.Module, recurse: bool = True): + if hasattr(module, "_gptq_pre_hook"): + module._gptq_pre_hook.remove() + delattr(module, "_gptq_pre_hook") - self.model.config.use_cache = forward_pass_use_cache + if hasattr(module, "_gptq_post_hook"): + module._gptq_post_hook.remove() + delattr(module, "_gptq_post_hook") - # re-enable quantization - self.model.apply(enable_quantization) + if recurse: + for child_module in module.children(): + self.remove_hooks(child_module) def _build_quant_modifier(self): """ @@ -369,20 +341,3 @@ def _build_quant_modifier_from_dict(self, quant_config): allow_experimental=True, **modifier_args, ) - - def _pruning_arguments(self): - """ - Gather the parameters needed for root module compression in a dict - - :return: dict of params for pruning - """ - return { - "blocksize": self.block_size, - "percdamp": self.dampening_frac, - } - - def _compression_class(self): - """ - :return: wrapper class used for root modules of this compression class - """ - return GPTQWrapper diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/compress.py b/src/llmcompressor/modifiers/quantization/gptq/utils/compress.py new file mode 100644 index 000000000..05a856072 --- /dev/null +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/compress.py @@ -0,0 +1,278 @@ +import torch + + def add_batch(self, inp: torch.Tensor, out: torch.Tensor): + """ + Add a batch of layer input and output data to the Hessian calculation + + :param inp: tensor containing layer input + :param out: tensor containing layer output + """ + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if isinstance(self.layer, nn.Linear) or isinstance( + self.layer, transformers.Conv1D + ): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + self.H *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + inp = inp.to(dtype=self.H.dtype) + inp = math.sqrt(2 / self.nsamples) * inp + self.H += inp.matmul(inp.t()) + + def compress( + self, + blocksize: int = 128, + percdamp: float = 0.01, + ): + """ + Run pruning and quantization(if applicable) on the layer up to the target + sparsity value. + + :param blocksize: Number of columns to compress in one pass + :param percdamp: Amount of dampening to apply to H, as a fraction of the + diagonal norm + """ + args_loc = "quantization_scheme.weights" + weight_quant_args = getattr_chain(self.layer, args_loc, None) + if weight_quant_args is None: + logger.debug(f"Skipping unquantized layer {self.name}...") + return + + if is_module_offloaded(self.layer): + self.layer._hf_hook.pre_forward(self.layer) + + strategy = weight_quant_args.strategy + actorder = weight_quant_args.actorder + final_shape = self.layer.weight.shape + final_dtype = self.layer.weight.dtype + W = self.layer.weight.data.clone() + + # standardize shape and dtype + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + elif isinstance(self.layer, transformers.Conv1D): + W.transpose_(0, 1) + W = W.float() + + tick = time.time() + + if strategy == QuantizationStrategy.GROUP: + # mapping from column index to group index + g_idx = ( + torch.arange(self.columns, device=W.device, dtype=torch.int) + // weight_quant_args.group_size + ) + + if actorder == ActivationOrdering.GROUP: + # permute by activation order first, then update groups + W, self.H, perm = self._apply_activation_ordering(W, self.H) + self._update_quantization_parameters(weight_quant_args, W) + + # use identity g_idx (invert permutation later) + + elif actorder == ActivationOrdering.WEIGHT: + # update groups first, then permute by activation order + self._update_quantization_parameters(weight_quant_args, W) + W, self.H, perm = self._apply_activation_ordering(W, self.H) + + # permute g_idx to maintain identity mapping after unpermutation + g_idx = g_idx[perm] + + scale = self.layer.weight_scale + zero_point = self.layer.weight_zero_point + + # sparsity mask + sparsity = tensor_sparsity(W) + preserve_zeros = sparsity >= SPARSITY_THRESHOLD + W_nz_mask = ( + (~torch.isclose(W, torch.zeros(1, device=W.device).float())).float() + if preserve_zeros + else None + ) + + # mask dead hessian values + dead = torch.diag(self.H) == 0 + self.H[dead, dead] = 1 + W[:, dead] = 0 + + Losses = torch.zeros(self.rows, device=self.dev) + + # compute inverse hessian in place to save memory + damp = percdamp * torch.mean(torch.diag(self.H)) + diag = torch.arange(self.columns, device=self.dev) + self.H[diag, diag] += damp + self.H = torch.linalg.cholesky(self.H) + self.H = torch.cholesky_inverse(self.H) + self.H = torch.linalg.cholesky(self.H, upper=True) + Hinv = self.H + + # See section 3.4 of https://arxiv.org/abs/2203.07259 + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + if preserve_zeros: + W1_nz_mask = W_nz_mask[:, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + q = w.clone() + + # quantize column + if strategy == QuantizationStrategy.TENSOR: + q = fake_quantize( + q, + scale, + zero_point, + self.layer.quantization_scheme.weights, + ) + elif strategy == QuantizationStrategy.CHANNEL: + q = fake_quantize( + q, + scale[:, 0], + zero_point[:, 0], + weight_quant_args, + ) + elif strategy == QuantizationStrategy.GROUP: + # get the group index for the current column + column_idx = i1 + i + group_index = g_idx[column_idx] + + # Since we're only applying quantization to a slice, this + # ends up being a channelwise application + altered_qargs = copy(weight_quant_args) + altered_qargs.strategy = QuantizationStrategy.CHANNEL + q = fake_quantize( + q, + scale[:, group_index], + zero_point[:, group_index], + altered_qargs, + ) + else: + raise ValueError( + "Quantization strategy is not supported for GPTQ: " + f"{strategy}" + ) + + # propagate column error + Q1[:, i] = q + Losses1[:, i] = (w - q) ** 2 / d**2 + + err1 = (w - q) / d + w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + if preserve_zeros: + W1[:, i:] -= w1_err * W1_nz_mask[:, i:] + else: + W1[:, i:] -= w1_err + Err1[:, i] = err1 + + # propagate block error + W[:, i1:i2] = Q1 + Losses += torch.sum(Losses1, 1) / 2 + + w_err = Err1.matmul(Hinv[i1:i2, i2:]) + if preserve_zeros: + W[:, i2:] -= w_err * W_nz_mask[:, i2:] + else: + W[:, i2:] -= w_err + + if "METRIC" in logger._core.levels.keys(): + self._log_metrics(tick, Losses) + + if strategy == QuantizationStrategy.GROUP: + if actorder == ActivationOrdering.WEIGHT: + # restore original permutation + invperm = torch.argsort(perm) + W = W[:, invperm] + + elif actorder == ActivationOrdering.GROUP: + # restore original permutation + invperm = torch.argsort(perm) + W = W[:, invperm] + g_idx = g_idx[invperm] + + # only save g_idx if mapping is not identity + update_parameter_data(self.layer, g_idx, "weight_g_idx") + + if isinstance(self.layer, transformers.Conv1D): + W.transpose_(0, 1) + W = W.reshape(final_shape).to(final_dtype) + + # This is a bit hacky, but FSDP updates only work if we change + # the weight in place, clone() or direct assignment won't work + self.layer.weight -= self.layer.weight + self.layer.weight += W + + if is_module_offloaded(self.layer): + device = get_offloaded_device(self.layer) + update_prefix_dict(self.layer, "weight", self.layer.weight.to(device)) + self.layer._hf_hook.post_forward(self.layer, None) + + def free(self): + """ + Free the Hessian memory after the layer is complete + """ + delattr(self, "H") + super().free() + + def _update_quantization_parameters(self, args: QuantizationArgs, W: torch.Tensor): + """ + Update layer quantization parameters with potentially permuted weight + + :param args: quantization arguments + :param W: weight to calculate quantization parameters from + """ + observer = args.get_observer() + _scale, _zero_point = observer(W, g_idx=None) + update_parameter_data(self.layer, _scale, "weight_scale") + update_parameter_data(self.layer, _zero_point, "weight_zero_point") + + def _apply_activation_ordering( + self, W: torch.Tensor, H: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Permute weight and hessian in order of greatest outupt activations + + :param W: weight to permute + """ + perm = torch.argsort(torch.diag(H), descending=True) + return W[:, perm], H[perm][:, perm], perm + + def _log_metrics(self, start_tick: float, losses: torch.Tensor): + """ + Log metrics related to compression algorithm + + :param start_tick: time when algorithm started" + :param losses: loss as result of algorithm + """ + patch = logger.patch(lambda r: r.update(function="compress")) + patch.log("METRIC", "time %.2f" % (time.time() - start_tick)) + patch.log("METRIC", "error %.2f" % torch.sum(losses).item()) + + gpu_usage = get_GPU_memory_usage() + if len(gpu_usage) > 0: + for i in range(len(gpu_usage)): + perc = gpu_usage[i][0] * 100 + total_memory = int(gpu_usage[i][1]) # GB + patch.log( + "METRIC", + ( + f"GPU {i} | usage: {perc:.2f}%" + f" | total memory: {total_memory} GB" + ), + ) + + patch.log( + "METRIC", + f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB", + ) diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index 266acf973..0305c04df 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -22,8 +22,14 @@ from urllib.parse import urlparse import numpy +import torch from loguru import logger +from compressed_tensors.quantization import ( + disable_quantization, + enable_quantization, +) + __all__ = [ "ALL_TOKEN", "ALL_PRUNABLE_TOKEN", @@ -59,6 +65,7 @@ "is_package_available", "import_from_path", "getattr_chain", + "DisableKVCache", ] @@ -1041,3 +1048,41 @@ def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any: res = getattr(res, attr_name) return res + + +class DisableKVCache: + def __init__(self, model: torch.nn.Module): + if hasattr(model.config, "use_cache"): + self.config = model.config + + # MllamaConfig + elif hasattr(model.config, "text_config") and hasattr( + model.config.text_config, "use_cache" + ): + self.config = model.config.text_config + + # unknown config structure + else: + raise NotImplementedError( + f"Cannot find `use_cache` for config of type {type(model.config)}" + ) + + self.restore_value = self.config.use_cache + + def __enter__(self): + self.restore_value = self.config.use_cache + self.config.use_cache = False + + def __exit__(self, _exc_type, _exc_val, _exc_tb): + self.config.use_cache = self.restore_value + + +class DisableQuantization: + def __init__(self, model: torch.nn.Module): + self.model = model + + def __enter__(self): + self.model.apply(disable_quantization) + + def __exit__(self, _exc_type, _exc_val, _exc_tb): + self.model.apply(enable_quantization) \ No newline at end of file From e3a98cc12c6840fb4836d71d1b77d934512a46f0 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 17:08:33 +0000 Subject: [PATCH 007/285] WIP: begin quantize_weight --- .../modifiers/quantization/gptq/base.py | 69 +++---- .../utils/{compress.py => gptq_quantize.py} | 172 +++++++++--------- 2 files changed, 118 insertions(+), 123 deletions(-) rename src/llmcompressor/modifiers/quantization/gptq/utils/{compress.py => gptq_quantize.py} (66%) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index fd1c42116..9e9c31d6e 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -215,10 +215,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool: return True - def on_end(self): - self.register_hooks(state.model, layers) - self.dummy_forward() ??? - def on_finalize(self, state: "State", **kwargs) -> bool: """ disable the quantization observers used by the OBCQ algorithm @@ -234,49 +230,32 @@ def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Modu layers = layers.values() for name, module in model.named_modules(): - quant_args = getattr_chain(module, "quantization_scheme.weights", None) - if quant_args is not None: - module._gptq_pre_hook = module.register_forward_pre_hook( - partial(self.target_pre_forward, name, quant_args)) - module._gptq_post_hook = module.register_forward_hook( - partial(self.target_post_forward, name, quant_args)) + if getattr_chain(module, "quantization_scheme.weights", None) is not None: + pre_hook = partial(self.target_pre_forward, name) + post_hook = partial(self.target_post_forward, name) + module._gptq_pre_hook = module.register_forward_pre_hook(pre_hook) + module._gptq_post_hook = module.register_forward_hook(post_hook) if module in layers.values(): - module._gptq_pre_hook = module.register_forward_pre_hook( - partial(self.layer_pre_forward, name)) - module._gptq_post_hook = module.register_forward_hook( - partial(self.layer_post_forward, name)) + pre_hook = partial(self.layer_pre_forward, name) + post_hook = partial(self.layer_post_forward, name) + module._gptq_pre_hook = module.register_forward_pre_hook(pre_hook) + module._gptq_post_hook = module.register_forward_hook(post_hook) def calibration_forward(self, model: torch.nn.Module, data: torch.utils.data.Dataloader): all_data = torch.cat([batch for batch in data], dim=0) with DisableKVCache(model), DisableQuantization(model): model(all_data) - def target_pre_forward(self, name: str, quant_args: QuantizationScheme, module: torch.nn.Module, args, kwargs): + def target_pre_forward(self, name: str, module: torch.nn.Module, args, kwargs): if self.true_sequential: # compress first so output is from quantized weights - logger.info(f"Compressing {name}...") - gptq_compress( - module, - args, - kwargs, - quant_args, - block_size=self.block_size, - percdamp=self.dampening_frac, - ) + self.quantize_module(name, module, args) - def target_post_forward(self, name: str, quant_args: QuantizationScheme, module: torch.nn.Module, args, kwargs, output): + def target_post_forward(self, name: str, module: torch.nn.Module, args, kwargs, output): if not self.true_sequential: # compress after so output is from unquantized weights - logger.info(f"Compressing {name}...") - gptq_compress( - module, - args, - kwargs, - quant_args, - block_size=self.block_size, - percdamp=self.dampening_frac, - ) + self.quantize_module(name, module, args) def layer_pre_forward(self, name: str, module: torch.nn.Module, args, kwargs): logger.info(f"\n===== Compressing layer {self.layer_index}/{self.num_layers} =====") @@ -291,6 +270,28 @@ def layer_post_forward(self, name: str, module: torch.nn.Module, args, kwargs, o self.layer_index += 1 return output + def quantize_module(self, name, module, inp): + logger.info(f"Compressing {name}...") + + quant_args = getattr_chain(module, "quantization_scheme.weights") + # with onloaded weight + quantized_weight, scale, zero_point, g_idx = quantize_weight( + module.weight.data, + inp, + quant_args, + block_size=self.block_size, + percdamp=self.dampening_frac, + module_class=type(module), + ) + + # This is a bit hacky, but FSDP updates only work if we change + # the weight in place, clone() or direct assignment won't work + self.layer.weight -= self.layer.weight + self.layer.weight += lerp(module.weight.data, quantized_weight, self.alpha) + update_parameter_data(module, scale, "weight_scale") + update_parameter_data(module, zero_point, "weight_zero_point") + update_parameter_data(module, g_idx, "weight_g_idx") + def remove_hooks(self, module: torch.nn.Module, recurse: bool = True): if hasattr(module, "_gptq_pre_hook"): module._gptq_pre_hook.remove() diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/compress.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py similarity index 66% rename from src/llmcompressor/modifiers/quantization/gptq/utils/compress.py rename to src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 05a856072..66f111ffa 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/compress.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -1,85 +1,82 @@ -import torch - - def add_batch(self, inp: torch.Tensor, out: torch.Tensor): - """ - Add a batch of layer input and output data to the Hessian calculation - - :param inp: tensor containing layer input - :param out: tensor containing layer output - """ - if len(inp.shape) == 2: - inp = inp.unsqueeze(0) - tmp = inp.shape[0] - if isinstance(self.layer, nn.Linear) or isinstance( - self.layer, transformers.Conv1D - ): - if len(inp.shape) == 3: - inp = inp.reshape((-1, inp.shape[-1])) - inp = inp.t() - self.H *= self.nsamples / (self.nsamples + tmp) - self.nsamples += tmp - inp = inp.to(dtype=self.H.dtype) - inp = math.sqrt(2 / self.nsamples) * inp - self.H += inp.matmul(inp.t()) - - def compress( - self, - blocksize: int = 128, - percdamp: float = 0.01, - ): - """ - Run pruning and quantization(if applicable) on the layer up to the target - sparsity value. - - :param blocksize: Number of columns to compress in one pass - :param percdamp: Amount of dampening to apply to H, as a fraction of the - diagonal norm - """ - args_loc = "quantization_scheme.weights" - weight_quant_args = getattr_chain(self.layer, args_loc, None) - if weight_quant_args is None: - logger.debug(f"Skipping unquantized layer {self.name}...") - return - - if is_module_offloaded(self.layer): - self.layer._hf_hook.pre_forward(self.layer) - - strategy = weight_quant_args.strategy - actorder = weight_quant_args.actorder - final_shape = self.layer.weight.shape - final_dtype = self.layer.weight.dtype - W = self.layer.weight.data.clone() - - # standardize shape and dtype - if isinstance(self.layer, nn.Conv2d): - W = W.flatten(1) - elif isinstance(self.layer, transformers.Conv1D): - W.transpose_(0, 1) - W = W.float() +from typing import Any - tick = time.time() - - if strategy == QuantizationStrategy.GROUP: - # mapping from column index to group index - g_idx = ( - torch.arange(self.columns, device=W.device, dtype=torch.int) - // weight_quant_args.group_size - ) +import time +import math +import torch +from compressed_tensors.quantization import QuantizationArguments + + +def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: + inp = inp.to(device=device) + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + + if module_class in (torch.nn.Linear, transformers.Conv1D): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + + nsamples = inp.shape[0] + + inp = inp.to(dtype=torch.float32) + inp = math.sqrt(2 / nsamples) * inp + return inp.matmul(inp.t()) + + +def invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(H.shape[0], device=H.device) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + return H + + +def quantize_weight( + weight: torch.Tensor, + inp: torch.Tensor, + quant_args: QuantizationArguments, + block_size: int = 128, + percdamp: float = 0.01, + module_class = torch.nn.Linear, +) -> Tuple[torch.nn.Parameter, ]: + strategy = quant_args.strategy + actorder = quant_args.actorder + final_shape = weight.shape + final_dtype = weight.dtype + W = weight.data.clone() + + # standardize shape and dtype + if module_class == torch.nn.Conv2d: + W = W.flatten(1) + elif module_class == transformers.Conv1D: + W.transpose_(0, 1) + W = W.to(dtype=torch.float32) + + tick = time.time() + + if strategy == QuantizationStrategy.GROUP: + # mapping from column index to group index + g_idx = ( + torch.arange(self.columns, device=W.device, dtype=torch.int) + // weight_quant_args.group_size + ) - if actorder == ActivationOrdering.GROUP: - # permute by activation order first, then update groups - W, self.H, perm = self._apply_activation_ordering(W, self.H) - self._update_quantization_parameters(weight_quant_args, W) + if actorder == ActivationOrdering.GROUP: + # permute by activation order first, then update groups + W, self.H, perm = self._apply_activation_ordering(W, self.H) + self._update_quantization_parameters(weight_quant_args, W) - # use identity g_idx (invert permutation later) + # use identity g_idx (invert permutation later) - elif actorder == ActivationOrdering.WEIGHT: - # update groups first, then permute by activation order - self._update_quantization_parameters(weight_quant_args, W) - W, self.H, perm = self._apply_activation_ordering(W, self.H) + elif actorder == ActivationOrdering.WEIGHT: + # update groups first, then permute by activation order + self._update_quantization_parameters(weight_quant_args, W) + W, self.H, perm = self._apply_activation_ordering(W, self.H) - # permute g_idx to maintain identity mapping after unpermutation - g_idx = g_idx[perm] + # permute g_idx to maintain identity mapping after unpermutation + g_idx = g_idx[perm] scale = self.layer.weight_scale zero_point = self.layer.weight_zero_point @@ -93,21 +90,18 @@ def compress( else None ) - # mask dead hessian values - dead = torch.diag(self.H) == 0 - self.H[dead, dead] = 1 - W[:, dead] = 0 - Losses = torch.zeros(self.rows, device=self.dev) # compute inverse hessian in place to save memory - damp = percdamp * torch.mean(torch.diag(self.H)) - diag = torch.arange(self.columns, device=self.dev) - self.H[diag, diag] += damp - self.H = torch.linalg.cholesky(self.H) - self.H = torch.cholesky_inverse(self.H) - self.H = torch.linalg.cholesky(self.H, upper=True) - Hinv = self.H + H = compute_hessian(inp, module_class, device=device) + + # mask dead hessian values + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + # TODO: check in place + Hinv = invert_hessian(H, percdamp) # See section 3.4 of https://arxiv.org/abs/2203.07259 for i1 in range(0, self.columns, blocksize): From bc9b3bcd889de1557c5fb1868b71a4274d2658f8 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 17:20:32 +0000 Subject: [PATCH 008/285] WIP --- .../quantization/gptq/utils/gptq_quantize.py | 43 ++++++++++--------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 66f111ffa..f0957b130 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -3,7 +3,8 @@ import time import math import torch -from compressed_tensors.quantization import QuantizationArguments +import transformers +from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy, ActivationOrdering def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: @@ -36,8 +37,8 @@ def invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: def quantize_weight( weight: torch.Tensor, inp: torch.Tensor, - quant_args: QuantizationArguments, - block_size: int = 128, + quant_args: QuantizationArgs, + blocksize: int = 128, percdamp: float = 0.01, module_class = torch.nn.Linear, ) -> Tuple[torch.nn.Parameter, ]: @@ -45,7 +46,10 @@ def quantize_weight( actorder = quant_args.actorder final_shape = weight.shape final_dtype = weight.dtype + num_columns = weight.shape[1] W = weight.data.clone() + + H = compute_hessian(inp, module_class, device=device) # standardize shape and dtype if module_class == torch.nn.Conv2d: @@ -56,31 +60,30 @@ def quantize_weight( tick = time.time() + scale, zero_point = compute_scale_zeropoint(W) + if strategy == QuantizationStrategy.GROUP: # mapping from column index to group index g_idx = ( - torch.arange(self.columns, device=W.device, dtype=torch.int) - // weight_quant_args.group_size + torch.arange(num_columns, device=W.device, dtype=torch.int) + // quant_args.group_size ) if actorder == ActivationOrdering.GROUP: # permute by activation order first, then update groups - W, self.H, perm = self._apply_activation_ordering(W, self.H) - self._update_quantization_parameters(weight_quant_args, W) + W, H, perm = _apply_activation_ordering(W, H) + scale, zero_point = _update_quantization_parameters(quant_args, W) # use identity g_idx (invert permutation later) elif actorder == ActivationOrdering.WEIGHT: # update groups first, then permute by activation order - self._update_quantization_parameters(weight_quant_args, W) - W, self.H, perm = self._apply_activation_ordering(W, self.H) + scale, zero_point = _update_quantization_parameters(quant_args, W) + W, H, perm = _apply_activation_ordering(W, H) # permute g_idx to maintain identity mapping after unpermutation g_idx = g_idx[perm] - scale = self.layer.weight_scale - zero_point = self.layer.weight_zero_point - # sparsity mask sparsity = tensor_sparsity(W) preserve_zeros = sparsity >= SPARSITY_THRESHOLD @@ -90,22 +93,20 @@ def quantize_weight( else None ) - Losses = torch.zeros(self.rows, device=self.dev) - - # compute inverse hessian in place to save memory - H = compute_hessian(inp, module_class, device=device) + Losses = torch.zeros(num_columns, device=weight.device) # mask dead hessian values dead = torch.diag(H) == 0 H[dead, dead] = 1 W[:, dead] = 0 + # compute inverse hessian in place to save memory # TODO: check in place Hinv = invert_hessian(H, percdamp) # See section 3.4 of https://arxiv.org/abs/2203.07259 - for i1 in range(0, self.columns, blocksize): - i2 = min(i1 + blocksize, self.columns) + for i1 in range(0, num_columns, blocksize): + i2 = min(i1 + blocksize, num_columns) count = i2 - i1 W1 = W[:, i1:i2].clone() @@ -128,14 +129,14 @@ def quantize_weight( q, scale, zero_point, - self.layer.quantization_scheme.weights, + quant_args, ) elif strategy == QuantizationStrategy.CHANNEL: q = fake_quantize( q, scale[:, 0], zero_point[:, 0], - weight_quant_args, + quant_args, ) elif strategy == QuantizationStrategy.GROUP: # get the group index for the current column @@ -144,7 +145,7 @@ def quantize_weight( # Since we're only applying quantization to a slice, this # ends up being a channelwise application - altered_qargs = copy(weight_quant_args) + altered_qargs = copy(quant_args) altered_qargs.strategy = QuantizationStrategy.CHANNEL q = fake_quantize( q, From b77c7bf3effbd8b96ee65f5cd2e888a1a9d205a4 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 18:55:56 +0000 Subject: [PATCH 009/285] WIP --- .../modifiers/quantization/gptq/base.py | 43 +++++++++---- .../quantization/gptq/utils/gptq_quantize.py | 60 ++++++++++--------- .../quantization/gptq/utils/helpers.py | 12 +++- 3 files changed, 72 insertions(+), 43 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 9e9c31d6e..8bec38fa0 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch +import contextlib from functools import partial from compressed_tensors.quantization import ( QuantizationScheme, @@ -18,6 +19,7 @@ from llmcompressor.modifiers.quantization.gptq.utils import ( GPTQWrapper, get_output_error, + gptq_hook ) from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward @@ -194,6 +196,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self.quantization_modifier_.initialize(state, **kwargs) if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") + + # after lifecycle refactor, all of this moves to pre_batch # find layers (used for printing even if true_sequential=True) # if no targets are provided, default to the modules that shouldn't be @@ -224,6 +228,8 @@ def on_finalize(self, state: "State", **kwargs) -> bool: if self.quantization_modifier_: self.quantization_modifier_.finalize(state, **kwargs) + self.remove_gptq_hooks(state.model) + return True def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Module]): @@ -247,25 +253,28 @@ def calibration_forward(self, model: torch.nn.Module, data: torch.utils.data.Dat with DisableKVCache(model), DisableQuantization(model): model(all_data) + @gptq_hook def target_pre_forward(self, name: str, module: torch.nn.Module, args, kwargs): if self.true_sequential: # compress first so output is from quantized weights self.quantize_module(name, module, args) + @gptq_hook def target_post_forward(self, name: str, module: torch.nn.Module, args, kwargs, output): if not self.true_sequential: # compress after so output is from unquantized weights self.quantize_module(name, module, args) + @gptq_hook def layer_pre_forward(self, name: str, module: torch.nn.Module, args, kwargs): logger.info(f"\n===== Compressing layer {self.layer_index}/{self.num_layers} =====") + @gptq_hook def layer_post_forward(self, name: str, module: torch.nn.Module, args, kwargs, output): - self.remove_hooks(module) - if not self.true_sequential: # rerun with (now) quantized weights - output = module(*args, **kwargs) + with self.disable_hooks(): + output = module(*args, **kwargs) self.layer_index += 1 return output @@ -283,16 +292,23 @@ def quantize_module(self, name, module, inp): percdamp=self.dampening_frac, module_class=type(module), ) - - # This is a bit hacky, but FSDP updates only work if we change - # the weight in place, clone() or direct assignment won't work - self.layer.weight -= self.layer.weight - self.layer.weight += lerp(module.weight.data, quantized_weight, self.alpha) - update_parameter_data(module, scale, "weight_scale") - update_parameter_data(module, zero_point, "weight_zero_point") - update_parameter_data(module, g_idx, "weight_g_idx") - - def remove_hooks(self, module: torch.nn.Module, recurse: bool = True): + + weight = lerp(module.weight.data, quantized_weight, self.alpha) + + update_prefix_dict(self.layer, "weight", weight) + update_parameter_data(module, scale, "weight_scale") + update_parameter_data(module, zero_point, "weight_zero_point") + update_parameter_data(module, g_idx, "weight_g_idx") + + @contextlib.contextmanager + def disable_hooks(self): + try: + self.hooks_disabled = True + yield + finally: + self.hooks_disabled = False + + def remove_gptq_hooks(self, module: torch.nn.Module, recurse: bool = True): if hasattr(module, "_gptq_pre_hook"): module._gptq_pre_hook.remove() delattr(module, "_gptq_pre_hook") @@ -305,6 +321,7 @@ def remove_hooks(self, module: torch.nn.Module, recurse: bool = True): for child_module in module.children(): self.remove_hooks(child_module) + def _build_quant_modifier(self): """ Build a quantization modifier based on the specified config_groups, diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index f0957b130..36ab9055f 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -182,7 +182,7 @@ def quantize_weight( W[:, i2:] -= w_err if "METRIC" in logger._core.levels.keys(): - self._log_metrics(tick, Losses) + _log_metrics(tick, Losses) if strategy == QuantizationStrategy.GROUP: if actorder == ActivationOrdering.WEIGHT: @@ -213,6 +213,8 @@ def quantize_weight( update_prefix_dict(self.layer, "weight", self.layer.weight.to(device)) self.layer._hf_hook.post_forward(self.layer, None) + return W, scale, zero_point, g_idx + def free(self): """ Free the Hessian memory after the layer is complete @@ -243,31 +245,31 @@ def _apply_activation_ordering( perm = torch.argsort(torch.diag(H), descending=True) return W[:, perm], H[perm][:, perm], perm - def _log_metrics(self, start_tick: float, losses: torch.Tensor): - """ - Log metrics related to compression algorithm - - :param start_tick: time when algorithm started" - :param losses: loss as result of algorithm - """ - patch = logger.patch(lambda r: r.update(function="compress")) - patch.log("METRIC", "time %.2f" % (time.time() - start_tick)) - patch.log("METRIC", "error %.2f" % torch.sum(losses).item()) - - gpu_usage = get_GPU_memory_usage() - if len(gpu_usage) > 0: - for i in range(len(gpu_usage)): - perc = gpu_usage[i][0] * 100 - total_memory = int(gpu_usage[i][1]) # GB - patch.log( - "METRIC", - ( - f"GPU {i} | usage: {perc:.2f}%" - f" | total memory: {total_memory} GB" - ), - ) - - patch.log( - "METRIC", - f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB", - ) +def _log_metrics(self, start_tick: float, losses: torch.Tensor): + """ + Log metrics related to compression algorithm + + :param start_tick: time when algorithm started" + :param losses: loss as result of algorithm + """ + patch = logger.patch(lambda r: r.update(function="compress")) + patch.log("METRIC", "time %.2f" % (time.time() - start_tick)) + patch.log("METRIC", "error %.2f" % torch.sum(losses).item()) + + gpu_usage = get_GPU_memory_usage() + if len(gpu_usage) > 0: + for i in range(len(gpu_usage)): + perc = gpu_usage[i][0] * 100 + total_memory = int(gpu_usage[i][1]) # GB + patch.log( + "METRIC", + ( + f"GPU {i} | usage: {perc:.2f}%" + f" | total memory: {total_memory} GB" + ), + ) + + patch.log( + "METRIC", + f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB", + ) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py index 58fedc634..f226e41c0 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py @@ -2,7 +2,7 @@ import torch -__all__ = ["get_output_error"] +__all__ = ["get_output_error", "gptq_hook"] def get_output_error( @@ -49,3 +49,13 @@ def get_output_error( for unq, q in zip(unquantized_outputs, quantized_outputs) ] ) / len(unquantized_outputs) + + +def gptq_hook(func): + def wrapped(self, *args, **kwargs): + if self.hooks_disabled: + return + + func(self, *args, **kwargs) + + return wrapped \ No newline at end of file From 7be5aed7e2996ed4d855ae6f246443784cd43c80 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 20:59:39 +0000 Subject: [PATCH 010/285] wip --- .../modifiers/quantization/gptq/base.py | 9 +++- .../quantization/gptq/utils/gptq_quantize.py | 43 ++++++++++--------- 2 files changed, 31 insertions(+), 21 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 8bec38fa0..b5da7cce0 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -31,6 +31,13 @@ qat_active, ) +from compressed_tensors.utils import ( + get_offloaded_device, + is_module_offloaded, + update_parameter_data, + update_prefix_dict, +) + __all__ = ["GPTQModifier"] @@ -293,7 +300,7 @@ def quantize_module(self, name, module, inp): module_class=type(module), ) - weight = lerp(module.weight.data, quantized_weight, self.alpha) + weight = torch.lerp(module.weight.data, quantized_weight, self.alpha) update_prefix_dict(self.layer, "weight", weight) update_parameter_data(module, scale, "weight_scale") diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 36ab9055f..d5b1efef1 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -1,10 +1,16 @@ -from typing import Any +from typing import Tuple, Union import time import math import torch import transformers -from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy, ActivationOrdering +from copy import copy +from loguru import logger +from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy, ActivationOrdering, fake_quantize +from llmcompressor.utils.metric_logging import ( + get_GPU_memory_usage, + get_layer_size_bytes, +) def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: @@ -41,7 +47,7 @@ def quantize_weight( blocksize: int = 128, percdamp: float = 0.01, module_class = torch.nn.Linear, -) -> Tuple[torch.nn.Parameter, ]: +) -> Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]: strategy = quant_args.strategy actorder = quant_args.actorder final_shape = weight.shape @@ -49,7 +55,7 @@ def quantize_weight( num_columns = weight.shape[1] W = weight.data.clone() - H = compute_hessian(inp, module_class, device=device) + H = compute_hessian(inp, module_class, device=weight.device) # standardize shape and dtype if module_class == torch.nn.Conv2d: @@ -60,8 +66,6 @@ def quantize_weight( tick = time.time() - scale, zero_point = compute_scale_zeropoint(W) - if strategy == QuantizationStrategy.GROUP: # mapping from column index to group index g_idx = ( @@ -72,18 +76,23 @@ def quantize_weight( if actorder == ActivationOrdering.GROUP: # permute by activation order first, then update groups W, H, perm = _apply_activation_ordering(W, H) - scale, zero_point = _update_quantization_parameters(quant_args, W) + scale, zero_point = compute_scale_zeropoint(W, quant_args) # use identity g_idx (invert permutation later) elif actorder == ActivationOrdering.WEIGHT: # update groups first, then permute by activation order - scale, zero_point = _update_quantization_parameters(quant_args, W) + scale, zero_point = compute_scale_zeropoint(W, quant_args) W, H, perm = _apply_activation_ordering(W, H) # permute g_idx to maintain identity mapping after unpermutation g_idx = g_idx[perm] + else: + scale, zero_point = compute_scale_zeropoint(W, quant_args) + else: + scale, zero_point = compute_scale_zeropoint(W, quant_args) + # sparsity mask sparsity = tensor_sparsity(W) preserve_zeros = sparsity >= SPARSITY_THRESHOLD @@ -184,6 +193,7 @@ def quantize_weight( if "METRIC" in logger._core.levels.keys(): _log_metrics(tick, Losses) + has_gidx = False if strategy == QuantizationStrategy.GROUP: if actorder == ActivationOrdering.WEIGHT: # restore original permutation @@ -197,22 +207,15 @@ def quantize_weight( g_idx = g_idx[invperm] # only save g_idx if mapping is not identity - update_parameter_data(self.layer, g_idx, "weight_g_idx") + has_gidx = True + + if not has_gidx: + g_idx = None - if isinstance(self.layer, transformers.Conv1D): + if module_class == transformers.Conv1D: W.transpose_(0, 1) W = W.reshape(final_shape).to(final_dtype) - # This is a bit hacky, but FSDP updates only work if we change - # the weight in place, clone() or direct assignment won't work - self.layer.weight -= self.layer.weight - self.layer.weight += W - - if is_module_offloaded(self.layer): - device = get_offloaded_device(self.layer) - update_prefix_dict(self.layer, "weight", self.layer.weight.to(device)) - self.layer._hf_hook.post_forward(self.layer, None) - return W, scale, zero_point, g_idx def free(self): From e01094fed95d4be087c702f841cf687c76347690 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 22:01:27 +0000 Subject: [PATCH 011/285] compilable --- .../modifiers/quantization/gptq/base.py | 48 +++++++++++++-- .../quantization/gptq/utils/gptq_quantize.py | 60 ++++++------------- .../quantization/gptq/utils/helpers.py | 49 ++++++++++++++- src/llmcompressor/utils/helpers.py | 17 +++++- 4 files changed, 126 insertions(+), 48 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index b5da7cce0..490b46450 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -21,10 +21,13 @@ get_output_error, gptq_hook ) +from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import quantize_weight +from llmcompressor.modifiers.quantization.gptq.utils.helpers import LogMetrics from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.utils.fsdp.context import fix_fsdp_module_name -from llmcompressor.utils.helpers import DisableKVCache, DisableQuantization, getattr_chain +from llmcompressor.utils.helpers import DisableKVCache, DisableQuantization, OnloadModule, getattr_chain +from llmcompressor.utils.metric_logging import get_GPU_memory_usage, get_layer_size_bytes from llmcompressor.utils.pytorch.module import ( get_layers, get_no_split_params, @@ -203,8 +206,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self.quantization_modifier_.initialize(state, **kwargs) if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") - - # after lifecycle refactor, all of this moves to pre_batch # find layers (used for printing even if true_sequential=True) # if no targets are provided, default to the modules that shouldn't be @@ -216,12 +217,14 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self.num_layers = len(layers) # add hooks to targets and layers + # after lifecycle refactor, move this to pre_batch self.register_hooks(state.model, layers) # apply calibration and trigger hooks (hooks are self removing) self.calibration_forward(state.model, state.data.calib) # freeze quantization + # after lifecycle refactor, move this to post_batch state.model.apply(freeze_module_quantization) return True @@ -291,7 +294,8 @@ def quantize_module(self, name, module, inp): quant_args = getattr_chain(module, "quantization_scheme.weights") # with onloaded weight - quantized_weight, scale, zero_point, g_idx = quantize_weight( + with OnloadModule(module), LogMetrics(module) as logger: + losses, quantized_weight, scale, zero_point, g_idx = quantize_weight( module.weight.data, inp, quant_args, @@ -302,10 +306,13 @@ def quantize_module(self, name, module, inp): weight = torch.lerp(module.weight.data, quantized_weight, self.alpha) - update_prefix_dict(self.layer, "weight", weight) + if is_module_offloaded(module): + update_prefix_dict(self.layer, "weight", weight) update_parameter_data(module, scale, "weight_scale") update_parameter_data(module, zero_point, "weight_zero_point") update_parameter_data(module, g_idx, "weight_g_idx") + + logger.set_losses(losses) @contextlib.contextmanager def disable_hooks(self): @@ -329,6 +336,37 @@ def remove_gptq_hooks(self, module: torch.nn.Module, recurse: bool = True): self.remove_hooks(child_module) + def _log_metrics(start_tick: float, losses: torch.Tensor): + """ + Log metrics related to compression algorithm + + :param start_tick: time when algorithm started" + :param losses: loss as result of algorithm + """ + patch = logger.patch(lambda r: r.update(function="compress")) + patch.log("METRIC", "time %.2f" % (time.time() - start_tick)) + patch.log("METRIC", "error %.2f" % torch.sum(losses).item()) + + gpu_usage = get_GPU_memory_usage() + if len(gpu_usage) > 0: + for i in range(len(gpu_usage)): + perc = gpu_usage[i][0] * 100 + total_memory = int(gpu_usage[i][1]) # GB + patch.log( + "METRIC", + ( + f"GPU {i} | usage: {perc:.2f}%" + f" | total memory: {total_memory} GB" + ), + ) + + patch.log( + "METRIC", + f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB", + ) + + + def _build_quant_modifier(self): """ Build a quantization modifier based on the specified config_groups, diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index d5b1efef1..4ecdfe837 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -47,7 +47,7 @@ def quantize_weight( blocksize: int = 128, percdamp: float = 0.01, module_class = torch.nn.Linear, -) -> Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]: strategy = quant_args.strategy actorder = quant_args.actorder final_shape = weight.shape @@ -102,7 +102,7 @@ def quantize_weight( else None ) - Losses = torch.zeros(num_columns, device=weight.device) + losses = torch.zeros(num_columns, device=weight.device) # mask dead hessian values dead = torch.diag(H) == 0 @@ -121,7 +121,7 @@ def quantize_weight( W1 = W[:, i1:i2].clone() Q1 = torch.zeros_like(W1) Err1 = torch.zeros_like(W1) - Losses1 = torch.zeros_like(W1) + losses1 = torch.zeros_like(W1) Hinv1 = Hinv[i1:i2, i1:i2] if preserve_zeros: @@ -170,7 +170,7 @@ def quantize_weight( # propagate column error Q1[:, i] = q - Losses1[:, i] = (w - q) ** 2 / d**2 + losses1[:, i] = (w - q) ** 2 / d**2 err1 = (w - q) / d w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) @@ -182,7 +182,7 @@ def quantize_weight( # propagate block error W[:, i1:i2] = Q1 - Losses += torch.sum(Losses1, 1) / 2 + losses += torch.sum(losses1, 1) / 2 w_err = Err1.matmul(Hinv[i1:i2, i2:]) if preserve_zeros: @@ -190,9 +190,6 @@ def quantize_weight( else: W[:, i2:] -= w_err - if "METRIC" in logger._core.levels.keys(): - _log_metrics(tick, Losses) - has_gidx = False if strategy == QuantizationStrategy.GROUP: if actorder == ActivationOrdering.WEIGHT: @@ -216,39 +213,20 @@ def quantize_weight( W.transpose_(0, 1) W = W.reshape(final_shape).to(final_dtype) - return W, scale, zero_point, g_idx - - def free(self): - """ - Free the Hessian memory after the layer is complete - """ - delattr(self, "H") - super().free() - - def _update_quantization_parameters(self, args: QuantizationArgs, W: torch.Tensor): - """ - Update layer quantization parameters with potentially permuted weight - - :param args: quantization arguments - :param W: weight to calculate quantization parameters from - """ - observer = args.get_observer() - _scale, _zero_point = observer(W, g_idx=None) - update_parameter_data(self.layer, _scale, "weight_scale") - update_parameter_data(self.layer, _zero_point, "weight_zero_point") - - def _apply_activation_ordering( - self, W: torch.Tensor, H: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Permute weight and hessian in order of greatest outupt activations - - :param W: weight to permute - """ - perm = torch.argsort(torch.diag(H), descending=True) - return W[:, perm], H[perm][:, perm], perm - -def _log_metrics(self, start_tick: float, losses: torch.Tensor): + return losses, W, scale, zero_point, g_idx + +def _apply_activation_ordering( + W: torch.Tensor, H: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Permute weight and hessian in order of greatest outupt activations + + :param W: weight to permute + """ + perm = torch.argsort(torch.diag(H), descending=True) + return W[:, perm], H[perm][:, perm], perm + +def _log_metrics(start_tick: float, losses: torch.Tensor): """ Log metrics related to compression algorithm diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py index f226e41c0..c15816892 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py @@ -1,6 +1,10 @@ from typing import Any, Iterable, List, Tuple, Union +import time import torch +from loguru import logger + +from llmcompressor.utils.metric_logging import get_GPU_memory_usage, get_layer_size_bytes __all__ = ["get_output_error", "gptq_hook"] @@ -58,4 +62,47 @@ def wrapped(self, *args, **kwargs): func(self, *args, **kwargs) - return wrapped \ No newline at end of file + return wrapped + + +class LogMetrics: + def __init__(self, module: torch.nn.Module): + self.module = module + self.start_tick = None + self.losses = None + + def set_losses(self, losses: torch.Tensor): + self.losses = losses + + def __enter__(self): + self.start_tick = time.time() + + def __exit__(self, _exc_type, _exc_val, _exc_tb): + """ + Log metrics related to compression algorithm + + :param start_tick: time when algorithm started" + :param losses: loss as result of algorithm + """ + patch = logger.patch(lambda r: r.update(function="compress")) + + if self.start_tick is not None: + patch.log("METRIC", "time %.2f" % (time.time() - self.start_tick)) + if self.losses is not None: + patch.log("METRIC", "error %.2f" % torch.sum(self.losses).item()) + + gpu_usage = get_GPU_memory_usage() + if len(gpu_usage) > 0: + for i in range(len(gpu_usage)): + perc = gpu_usage[i][0] * 100 + total_memory = int(gpu_usage[i][1]) # GB + patch.log( + "METRIC", + ( + f"GPU {i} | usage: {perc:.2f}%" + f" | total memory: {total_memory} GB" + ), + ) + + compressed_size = get_layer_size_bytes(self.module) + patch.log("METRIC", f"Compressed layer size: {compressed_size} MB") diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index 0305c04df..b46685110 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -29,6 +29,7 @@ disable_quantization, enable_quantization, ) +from compressed_tensors import is_module_offloaded __all__ = [ "ALL_TOKEN", @@ -1085,4 +1086,18 @@ def __enter__(self): self.model.apply(disable_quantization) def __exit__(self, _exc_type, _exc_val, _exc_tb): - self.model.apply(enable_quantization) \ No newline at end of file + self.model.apply(enable_quantization) + + +class OnloadModule: + def __init__(self, module: torch.nn.Module): + self.module = module + self.is_module_offloaded = is_module_offloaded(self.module) + + def __enter__(self): + if self.is_module_offloaded: + self.module._hf_hook.pre_forward(self.module) + + def __exit__(self, _exc_type, _exc_val, _exc_tb): + if self.is_module_offloaded: + self.module._hf_hook.post_forward(self.module, None) \ No newline at end of file From ad9f5a8d6100027ab045b0d44cc068d59c041c33 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 22:02:00 +0000 Subject: [PATCH 012/285] compilable --- .../modifiers/quantization/gptq/base.py | 32 ------------------- 1 file changed, 32 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 490b46450..cfcdfc529 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -335,38 +335,6 @@ def remove_gptq_hooks(self, module: torch.nn.Module, recurse: bool = True): for child_module in module.children(): self.remove_hooks(child_module) - - def _log_metrics(start_tick: float, losses: torch.Tensor): - """ - Log metrics related to compression algorithm - - :param start_tick: time when algorithm started" - :param losses: loss as result of algorithm - """ - patch = logger.patch(lambda r: r.update(function="compress")) - patch.log("METRIC", "time %.2f" % (time.time() - start_tick)) - patch.log("METRIC", "error %.2f" % torch.sum(losses).item()) - - gpu_usage = get_GPU_memory_usage() - if len(gpu_usage) > 0: - for i in range(len(gpu_usage)): - perc = gpu_usage[i][0] * 100 - total_memory = int(gpu_usage[i][1]) # GB - patch.log( - "METRIC", - ( - f"GPU {i} | usage: {perc:.2f}%" - f" | total memory: {total_memory} GB" - ), - ) - - patch.log( - "METRIC", - f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB", - ) - - - def _build_quant_modifier(self): """ Build a quantization modifier based on the specified config_groups, From e4ee0af5c3c32be52329f39d2f51f83a857cb656 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 22:52:13 +0000 Subject: [PATCH 013/285] wip --- .../modifiers/quantization/gptq/base.py | 73 ++-- .../quantization/gptq/utils/__init__.py | 2 +- .../quantization/gptq/utils/gptq_quantize.py | 29 -- .../quantization/gptq/utils/gptq_wrapper.py | 341 ------------------ .../quantization/gptq/utils/helpers.py | 2 +- 5 files changed, 50 insertions(+), 397 deletions(-) delete mode 100644 src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index cfcdfc529..804359957 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -6,28 +6,22 @@ from functools import partial from compressed_tensors.quantization import ( QuantizationScheme, - disable_quantization, - enable_quantization, freeze_module_quantization, ) from loguru import logger from pydantic import Field, field_validator -from torch.nn import Module from llmcompressor.core import State from llmcompressor.modifiers import Modifier, ModifierFactory from llmcompressor.modifiers.quantization.gptq.utils import ( - GPTQWrapper, get_output_error, gptq_hook ) from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import quantize_weight from llmcompressor.modifiers.quantization.gptq.utils.helpers import LogMetrics -from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor -from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward +from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier from llmcompressor.utils.fsdp.context import fix_fsdp_module_name from llmcompressor.utils.helpers import DisableKVCache, DisableQuantization, OnloadModule, getattr_chain -from llmcompressor.utils.metric_logging import get_GPU_memory_usage, get_layer_size_bytes from llmcompressor.utils.pytorch.module import ( get_layers, get_no_split_params, @@ -123,6 +117,11 @@ class GPTQModifier(Modifier): num_calibration_steps: Optional[int] = None scheme: Optional[Union[str, Dict[str, Any]]] = None + _layer_index: int = 0 + _num_layers: int = 0 + _hooks_disabled: bool = False + quantization_modifier_: Optional[QuantizationModifier] = None + @field_validator("sequential_update", mode="before") def validate_sequential_update(cls, value: bool) -> bool: if not value: @@ -137,8 +136,8 @@ def validate_sequential_update(cls, value: bool) -> bool: def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.current_layer_index = 0 - self.num_layers = 0 + self._layer_index = 0 + self._num_layers = 0 self.quantization_modifier_ = None def on_initialize_structure(self, state: State, **kwargs): @@ -214,7 +213,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: if self.sequential_targets is None: self.sequential_targets = get_no_split_params(state.model) layers = get_layers(self.sequential_targets, state.model) - self.num_layers = len(layers) + self._num_layers = len(layers) # add hooks to targets and layers # after lifecycle refactor, move this to pre_batch @@ -243,8 +242,6 @@ def on_finalize(self, state: "State", **kwargs) -> bool: return True def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Module]): - layers = layers.values() - for name, module in model.named_modules(): if getattr_chain(module, "quantization_scheme.weights", None) is not None: pre_hook = partial(self.target_pre_forward, name) @@ -256,37 +253,63 @@ def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Modu pre_hook = partial(self.layer_pre_forward, name) post_hook = partial(self.layer_post_forward, name) module._gptq_pre_hook = module.register_forward_pre_hook(pre_hook) - module._gptq_post_hook = module.register_forward_hook(post_hook) + module._gptq_post_hook = module.register_forward_hook(post_hook, with_kwargs=True) + + def calibration_forward(self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader): + import torch.nn.functional as F + + accumulated_data = {} # Dictionary to accumulate samples per key - def calibration_forward(self, model: torch.nn.Module, data: torch.utils.data.Dataloader): - all_data = torch.cat([batch for batch in data], dim=0) + def pad_tensor(tensor, max_len): + """Pads a tensor to the specified max_len along the second dimension (sequence length).""" + pad_size = max_len - tensor.size(1) # Calculate the padding size + return F.pad(tensor, (0, pad_size), value=0) # Pad on the right with zeros + + for batch in dataloader: + for key, value in batch.items(): + if key not in accumulated_data: + accumulated_data[key] = [] + accumulated_data[key].append(value) # Accumulate values for each key + + # Find maximum length for each key across all samples to ensure matching shapes + max_lengths = {} + for key, tensors in accumulated_data.items(): + max_lengths[key] = max([tensor.size(1) for tensor in tensors]) # Assuming the second dimension is the sequence length + + # Pad and concatenate for each key + concatenated_batch = { + key: torch.cat([pad_tensor(tensor, max_lengths[key]) for tensor in accumulated_data[key]], dim=0) + for key in accumulated_data + } + with DisableKVCache(model), DisableQuantization(model): - model(all_data) + model(**concatenated_batch) @gptq_hook - def target_pre_forward(self, name: str, module: torch.nn.Module, args, kwargs): + def target_pre_forward(self, name: str, module: torch.nn.Module, args): if self.true_sequential: # compress first so output is from quantized weights self.quantize_module(name, module, args) @gptq_hook - def target_post_forward(self, name: str, module: torch.nn.Module, args, kwargs, output): + def target_post_forward(self, name: str, module: torch.nn.Module, args: torch.Tensor, _output: Any): if not self.true_sequential: # compress after so output is from unquantized weights self.quantize_module(name, module, args) @gptq_hook - def layer_pre_forward(self, name: str, module: torch.nn.Module, args, kwargs): - logger.info(f"\n===== Compressing layer {self.layer_index}/{self.num_layers} =====") + def layer_pre_forward(self, name: str, module: torch.nn.Module, args: Any): + logger.info(f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====") + breakpoint() @gptq_hook - def layer_post_forward(self, name: str, module: torch.nn.Module, args, kwargs, output): + def layer_post_forward(self, name: str, module: torch.nn.Module, args: torch.Tensor, kwargs: Dict[str, Any], output: Any): if not self.true_sequential: # rerun with (now) quantized weights with self.disable_hooks(): - output = module(*args, **kwargs) + output = module(args, **kwargs) - self.layer_index += 1 + self._layer_index += 1 return output def quantize_module(self, name, module, inp): @@ -317,10 +340,10 @@ def quantize_module(self, name, module, inp): @contextlib.contextmanager def disable_hooks(self): try: - self.hooks_disabled = True + self._hooks_disabled = True yield finally: - self.hooks_disabled = False + self._hooks_disabled = False def remove_gptq_hooks(self, module: torch.nn.Module, recurse: bool = True): if hasattr(module, "_gptq_pre_hook"): diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py b/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py index a8673dfc2..5703ced46 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py @@ -1,4 +1,4 @@ # flake8: noqa -from .gptq_wrapper import * +from .gptq_quantize import * from .helpers import * diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 4ecdfe837..512741888 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -225,32 +225,3 @@ def _apply_activation_ordering( """ perm = torch.argsort(torch.diag(H), descending=True) return W[:, perm], H[perm][:, perm], perm - -def _log_metrics(start_tick: float, losses: torch.Tensor): - """ - Log metrics related to compression algorithm - - :param start_tick: time when algorithm started" - :param losses: loss as result of algorithm - """ - patch = logger.patch(lambda r: r.update(function="compress")) - patch.log("METRIC", "time %.2f" % (time.time() - start_tick)) - patch.log("METRIC", "error %.2f" % torch.sum(losses).item()) - - gpu_usage = get_GPU_memory_usage() - if len(gpu_usage) > 0: - for i in range(len(gpu_usage)): - perc = gpu_usage[i][0] * 100 - total_memory = int(gpu_usage[i][1]) # GB - patch.log( - "METRIC", - ( - f"GPU {i} | usage: {perc:.2f}%" - f" | total memory: {total_memory} GB" - ), - ) - - patch.log( - "METRIC", - f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB", - ) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py deleted file mode 100644 index d53b942eb..000000000 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py +++ /dev/null @@ -1,341 +0,0 @@ -import time -from typing import Tuple - -from compressed_tensors.quantization import ( - ActivationOrdering, - QuantizationArgs, - QuantizationStrategy, -) -from compressed_tensors.quantization.lifecycle.forward import fake_quantize - -from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD -from llmcompressor.modifiers.utils.compression_wrapper import ModuleCompressionWrapper -from llmcompressor.pytorch.utils.helpers import tensor_sparsity -from llmcompressor.utils import getattr_chain -from llmcompressor.utils.metric_logging import ( - get_GPU_memory_usage, - get_layer_size_bytes, -) - -try: - import transformers -except ImportError as err: - transformers = None - transformers_err = err - -import math -from copy import copy - -import torch -import torch.nn as nn -from compressed_tensors.utils import ( - get_offloaded_device, - is_module_offloaded, - update_parameter_data, - update_prefix_dict, -) -from loguru import logger - -__all__ = ["GPTQWrapper"] - - -class GPTQWrapper(ModuleCompressionWrapper): - """ - Runs GPTQ on a single module that contains no sub-modules - - Lifecycle: - - add_batch - - compress - - free - - :param name: name of module to run compression on - :param layer: module to run compression on - """ - - def __init__(self, name, layer): - super().__init__(name=name, layer=layer) - - # for Hessian calculation - self.register_buffer( - "H", - torch.zeros( - (self.columns, self.columns), device=self.dev, dtype=torch.float32 - ), - ) - - def add_batch(self, inp: torch.Tensor, out: torch.Tensor): - """ - Add a batch of layer input and output data to the Hessian calculation - - :param inp: tensor containing layer input - :param out: tensor containing layer output - """ - if len(inp.shape) == 2: - inp = inp.unsqueeze(0) - tmp = inp.shape[0] - if isinstance(self.layer, nn.Linear) or isinstance( - self.layer, transformers.Conv1D - ): - if len(inp.shape) == 3: - inp = inp.reshape((-1, inp.shape[-1])) - inp = inp.t() - self.H *= self.nsamples / (self.nsamples + tmp) - self.nsamples += tmp - inp = inp.to(dtype=self.H.dtype) - inp = math.sqrt(2 / self.nsamples) * inp - self.H += inp.matmul(inp.t()) - - def compress( - self, - blocksize: int = 128, - percdamp: float = 0.01, - ): - """ - Run pruning and quantization(if applicable) on the layer up to the target - sparsity value. - - :param blocksize: Number of columns to compress in one pass - :param percdamp: Amount of dampening to apply to H, as a fraction of the - diagonal norm - """ - args_loc = "quantization_scheme.weights" - weight_quant_args = getattr_chain(self.layer, args_loc, None) - if weight_quant_args is None: - logger.debug(f"Skipping unquantized layer {self.name}...") - return - - if is_module_offloaded(self.layer): - self.layer._hf_hook.pre_forward(self.layer) - - strategy = weight_quant_args.strategy - actorder = weight_quant_args.actorder - final_shape = self.layer.weight.shape - final_dtype = self.layer.weight.dtype - W = self.layer.weight.data.clone() - - # standardize shape and dtype - if isinstance(self.layer, nn.Conv2d): - W = W.flatten(1) - elif isinstance(self.layer, transformers.Conv1D): - W.transpose_(0, 1) - W = W.float() - - tick = time.time() - - if strategy == QuantizationStrategy.GROUP: - # mapping from column index to group index - g_idx = ( - torch.arange(self.columns, device=W.device, dtype=torch.int) - // weight_quant_args.group_size - ) - - if actorder == ActivationOrdering.GROUP: - # permute by activation order first, then update groups - W, self.H, perm = self._apply_activation_ordering(W, self.H) - self._update_quantization_parameters(weight_quant_args, W) - - # use identity g_idx (invert permutation later) - - elif actorder == ActivationOrdering.WEIGHT: - # update groups first, then permute by activation order - self._update_quantization_parameters(weight_quant_args, W) - W, self.H, perm = self._apply_activation_ordering(W, self.H) - - # permute g_idx to maintain identity mapping after unpermutation - g_idx = g_idx[perm] - - scale = self.layer.weight_scale - zero_point = self.layer.weight_zero_point - - # sparsity mask - sparsity = tensor_sparsity(W) - preserve_zeros = sparsity >= SPARSITY_THRESHOLD - W_nz_mask = ( - (~torch.isclose(W, torch.zeros(1, device=W.device).float())).float() - if preserve_zeros - else None - ) - - # mask dead hessian values - dead = torch.diag(self.H) == 0 - self.H[dead, dead] = 1 - W[:, dead] = 0 - - Losses = torch.zeros(self.rows, device=self.dev) - - # compute inverse hessian in place to save memory - damp = percdamp * torch.mean(torch.diag(self.H)) - diag = torch.arange(self.columns, device=self.dev) - self.H[diag, diag] += damp - self.H = torch.linalg.cholesky(self.H) - self.H = torch.cholesky_inverse(self.H) - self.H = torch.linalg.cholesky(self.H, upper=True) - Hinv = self.H - - # See section 3.4 of https://arxiv.org/abs/2203.07259 - for i1 in range(0, self.columns, blocksize): - i2 = min(i1 + blocksize, self.columns) - count = i2 - i1 - - W1 = W[:, i1:i2].clone() - Q1 = torch.zeros_like(W1) - Err1 = torch.zeros_like(W1) - Losses1 = torch.zeros_like(W1) - Hinv1 = Hinv[i1:i2, i1:i2] - - if preserve_zeros: - W1_nz_mask = W_nz_mask[:, i1:i2] - - for i in range(count): - w = W1[:, i] - d = Hinv1[i, i] - q = w.clone() - - # quantize column - if strategy == QuantizationStrategy.TENSOR: - q = fake_quantize( - q, - scale, - zero_point, - self.layer.quantization_scheme.weights, - ) - elif strategy == QuantizationStrategy.CHANNEL: - q = fake_quantize( - q, - scale[:, 0], - zero_point[:, 0], - weight_quant_args, - ) - elif strategy == QuantizationStrategy.GROUP: - # get the group index for the current column - column_idx = i1 + i - group_index = g_idx[column_idx] - - # Since we're only applying quantization to a slice, this - # ends up being a channelwise application - altered_qargs = copy(weight_quant_args) - altered_qargs.strategy = QuantizationStrategy.CHANNEL - q = fake_quantize( - q, - scale[:, group_index], - zero_point[:, group_index], - altered_qargs, - ) - else: - raise ValueError( - "Quantization strategy is not supported for GPTQ: " - f"{strategy}" - ) - - # propagate column error - Q1[:, i] = q - Losses1[:, i] = (w - q) ** 2 / d**2 - - err1 = (w - q) / d - w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) - if preserve_zeros: - W1[:, i:] -= w1_err * W1_nz_mask[:, i:] - else: - W1[:, i:] -= w1_err - Err1[:, i] = err1 - - # propagate block error - W[:, i1:i2] = Q1 - Losses += torch.sum(Losses1, 1) / 2 - - w_err = Err1.matmul(Hinv[i1:i2, i2:]) - if preserve_zeros: - W[:, i2:] -= w_err * W_nz_mask[:, i2:] - else: - W[:, i2:] -= w_err - - if "METRIC" in logger._core.levels.keys(): - self._log_metrics(tick, Losses) - - if strategy == QuantizationStrategy.GROUP: - if actorder == ActivationOrdering.WEIGHT: - # restore original permutation - invperm = torch.argsort(perm) - W = W[:, invperm] - - elif actorder == ActivationOrdering.GROUP: - # restore original permutation - invperm = torch.argsort(perm) - W = W[:, invperm] - g_idx = g_idx[invperm] - - # only save g_idx if mapping is not identity - update_parameter_data(self.layer, g_idx, "weight_g_idx") - - if isinstance(self.layer, transformers.Conv1D): - W.transpose_(0, 1) - W = W.reshape(final_shape).to(final_dtype) - - # This is a bit hacky, but FSDP updates only work if we change - # the weight in place, clone() or direct assignment won't work - self.layer.weight -= self.layer.weight - self.layer.weight += W - - if is_module_offloaded(self.layer): - device = get_offloaded_device(self.layer) - update_prefix_dict(self.layer, "weight", self.layer.weight.to(device)) - self.layer._hf_hook.post_forward(self.layer, None) - - def free(self): - """ - Free the Hessian memory after the layer is complete - """ - delattr(self, "H") - super().free() - - def _update_quantization_parameters(self, args: QuantizationArgs, W: torch.Tensor): - """ - Update layer quantization parameters with potentially permuted weight - - :param args: quantization arguments - :param W: weight to calculate quantization parameters from - """ - observer = args.get_observer() - _scale, _zero_point = observer(W, g_idx=None) - update_parameter_data(self.layer, _scale, "weight_scale") - update_parameter_data(self.layer, _zero_point, "weight_zero_point") - - def _apply_activation_ordering( - self, W: torch.Tensor, H: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Permute weight and hessian in order of greatest outupt activations - - :param W: weight to permute - """ - perm = torch.argsort(torch.diag(H), descending=True) - return W[:, perm], H[perm][:, perm], perm - - def _log_metrics(self, start_tick: float, losses: torch.Tensor): - """ - Log metrics related to compression algorithm - - :param start_tick: time when algorithm started" - :param losses: loss as result of algorithm - """ - patch = logger.patch(lambda r: r.update(function="compress")) - patch.log("METRIC", "time %.2f" % (time.time() - start_tick)) - patch.log("METRIC", "error %.2f" % torch.sum(losses).item()) - - gpu_usage = get_GPU_memory_usage() - if len(gpu_usage) > 0: - for i in range(len(gpu_usage)): - perc = gpu_usage[i][0] * 100 - total_memory = int(gpu_usage[i][1]) # GB - patch.log( - "METRIC", - ( - f"GPU {i} | usage: {perc:.2f}%" - f" | total memory: {total_memory} GB" - ), - ) - - patch.log( - "METRIC", - f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB", - ) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py index c15816892..413f5eaca 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py @@ -57,7 +57,7 @@ def get_output_error( def gptq_hook(func): def wrapped(self, *args, **kwargs): - if self.hooks_disabled: + if self._hooks_disabled: return func(self, *args, **kwargs) From d9ba539739f4e4cd3acfdde2df36f58e66a6bfc7 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 22:52:34 +0000 Subject: [PATCH 014/285] add example --- examples/quantization_w4a16/llama3_example.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index d587a6199..01d9dba8c 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -1,3 +1,4 @@ +import torch from datasets import load_dataset from transformers import AutoTokenizer @@ -5,8 +6,9 @@ from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot # Select model and load it. -MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" +#MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" #MODEL_ID = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" +MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" model = SparseAutoModelForCausalLM.from_pretrained( MODEL_ID, From 83a5762c932dc69d6bc9aa714ff39f5f1149b2e1 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 23:39:07 +0000 Subject: [PATCH 015/285] wip --- .../modifiers/quantization/gptq/base.py | 24 +- .../quantization/gptq/utils/gptq_quantize.py | 238 +++++++++--------- .../quantization/gptq/utils/helpers.py | 7 +- 3 files changed, 139 insertions(+), 130 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 804359957..3d70b2d40 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -18,7 +18,7 @@ gptq_hook ) from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import quantize_weight -from llmcompressor.modifiers.quantization.gptq.utils.helpers import LogMetrics +from llmcompressor.modifiers.quantization.gptq.utils.helpers import MetricsLogger from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier from llmcompressor.utils.fsdp.context import fix_fsdp_module_name from llmcompressor.utils.helpers import DisableKVCache, DisableQuantization, OnloadModule, getattr_chain @@ -106,6 +106,7 @@ class GPTQModifier(Modifier): """ sequential_update: bool = True + true_sequential: bool = False targets: Union[str, List[str], None] = None sequential_targets: Union[str, List[str], None] = None block_size: int = 128 @@ -256,12 +257,12 @@ def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Modu module._gptq_post_hook = module.register_forward_hook(post_hook, with_kwargs=True) def calibration_forward(self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader): + """ import torch.nn.functional as F accumulated_data = {} # Dictionary to accumulate samples per key def pad_tensor(tensor, max_len): - """Pads a tensor to the specified max_len along the second dimension (sequence length).""" pad_size = max_len - tensor.size(1) # Calculate the padding size return F.pad(tensor, (0, pad_size), value=0) # Pad on the right with zeros @@ -281,9 +282,12 @@ def pad_tensor(tensor, max_len): key: torch.cat([pad_tensor(tensor, max_lengths[key]) for tensor in accumulated_data[key]], dim=0) for key in accumulated_data } + """ + + batch = next(iter(dataloader)) with DisableKVCache(model), DisableQuantization(model): - model(**concatenated_batch) + model(**batch) @gptq_hook def target_pre_forward(self, name: str, module: torch.nn.Module, args): @@ -300,7 +304,6 @@ def target_post_forward(self, name: str, module: torch.nn.Module, args: torch.Te @gptq_hook def layer_pre_forward(self, name: str, module: torch.nn.Module, args: Any): logger.info(f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====") - breakpoint() @gptq_hook def layer_post_forward(self, name: str, module: torch.nn.Module, args: torch.Tensor, kwargs: Dict[str, Any], output: Any): @@ -312,22 +315,25 @@ def layer_post_forward(self, name: str, module: torch.nn.Module, args: torch.Ten self._layer_index += 1 return output - def quantize_module(self, name, module, inp): + def quantize_module(self, name, module, args): logger.info(f"Compressing {name}...") + inp = args[0] # Assume that first argument is input (true for most Module types) quant_args = getattr_chain(module, "quantization_scheme.weights") + # with onloaded weight - with OnloadModule(module), LogMetrics(module) as logger: + with OnloadModule(module), MetricsLogger(module) as metrics_logger: losses, quantized_weight, scale, zero_point, g_idx = quantize_weight( module.weight.data, inp, quant_args, - block_size=self.block_size, + blocksize=self.block_size, percdamp=self.dampening_frac, module_class=type(module), ) - weight = torch.lerp(module.weight.data, quantized_weight, self.alpha) + #weight = torch.lerp(module.weight.data, quantized_weight, self.alpha) + weight = quantized_weight if is_module_offloaded(module): update_prefix_dict(self.layer, "weight", weight) @@ -335,7 +341,7 @@ def quantize_module(self, name, module, inp): update_parameter_data(module, zero_point, "weight_zero_point") update_parameter_data(module, g_idx, "weight_g_idx") - logger.set_losses(losses) + metrics_logger.set_losses(losses) @contextlib.contextmanager def disable_hooks(self): diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 512741888..2f0d3120f 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -7,10 +7,9 @@ from copy import copy from loguru import logger from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy, ActivationOrdering, fake_quantize -from llmcompressor.utils.metric_logging import ( - get_GPU_memory_usage, - get_layer_size_bytes, -) +from compressed_tensors.quantization.observers import MovingAverageMinMaxObserver +from llmcompressor.pytorch.utils.helpers import tensor_sparsity +from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: @@ -40,6 +39,10 @@ def invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: return H +def compute_scale_zeropoint(W: torch.Tensor, quant_args: QuantizationArgs) -> Tuple[torch.Tensor, torch.Tensor]: + return MovingAverageMinMaxObserver(quant_args)(W) + + def quantize_weight( weight: torch.Tensor, inp: torch.Tensor, @@ -52,7 +55,6 @@ def quantize_weight( actorder = quant_args.actorder final_shape = weight.shape final_dtype = weight.dtype - num_columns = weight.shape[1] W = weight.data.clone() H = compute_hessian(inp, module_class, device=weight.device) @@ -63,8 +65,7 @@ def quantize_weight( elif module_class == transformers.Conv1D: W.transpose_(0, 1) W = W.to(dtype=torch.float32) - - tick = time.time() + num_columns = W.shape[0] if strategy == QuantizationStrategy.GROUP: # mapping from column index to group index @@ -93,127 +94,128 @@ def quantize_weight( else: scale, zero_point = compute_scale_zeropoint(W, quant_args) - # sparsity mask - sparsity = tensor_sparsity(W) - preserve_zeros = sparsity >= SPARSITY_THRESHOLD - W_nz_mask = ( - (~torch.isclose(W, torch.zeros(1, device=W.device).float())).float() - if preserve_zeros - else None - ) - - losses = torch.zeros(num_columns, device=weight.device) - - # mask dead hessian values - dead = torch.diag(H) == 0 - H[dead, dead] = 1 - W[:, dead] = 0 - - # compute inverse hessian in place to save memory - # TODO: check in place - Hinv = invert_hessian(H, percdamp) - - # See section 3.4 of https://arxiv.org/abs/2203.07259 - for i1 in range(0, num_columns, blocksize): - i2 = min(i1 + blocksize, num_columns) - count = i2 - i1 + # sparsity mask + sparsity = tensor_sparsity(W) + preserve_zeros = sparsity >= SPARSITY_THRESHOLD + W_nz_mask = ( + (~torch.isclose(W, torch.zeros(1, device=W.device).float())).float() + if preserve_zeros + else None + ) + + losses = torch.zeros(num_columns, device=weight.device) + + # mask dead hessian values + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + # compute inverse hessian in place to save memory + # TODO: check in place + Hinv = invert_hessian(H, percdamp) + + # See section 3.4 of https://arxiv.org/abs/2203.07259 + for i1 in range(0, num_columns, blocksize): + i2 = min(i1 + blocksize, num_columns) + count = i2 - i1 + print((i1, i2, num_columns)) + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + if preserve_zeros: + W1_nz_mask = W_nz_mask[:, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + q = w.clone() + + # quantize column + if strategy == QuantizationStrategy.TENSOR: + q = fake_quantize( + q, + scale, + zero_point, + quant_args, + ) + elif strategy == QuantizationStrategy.CHANNEL: + q = fake_quantize( + q, + scale[:, 0], + zero_point[:, 0], + quant_args, + ) + elif strategy == QuantizationStrategy.GROUP: + # get the group index for the current column + column_idx = i1 + i + group_index = g_idx[column_idx] + + # Since we're only applying quantization to a slice, this + # ends up being a channelwise application + altered_qargs = copy(quant_args) + altered_qargs.strategy = QuantizationStrategy.CHANNEL + q = fake_quantize( + q, + scale[:, group_index], + zero_point[:, group_index], + altered_qargs, + ) + else: + raise ValueError( + "Quantization strategy is not supported for GPTQ: " + f"{strategy}" + ) - W1 = W[:, i1:i2].clone() - Q1 = torch.zeros_like(W1) - Err1 = torch.zeros_like(W1) - losses1 = torch.zeros_like(W1) - Hinv1 = Hinv[i1:i2, i1:i2] + # propagate column error + Q1[:, i] = q + losses1[:, i] = (w - q) ** 2 / d**2 + err1 = (w - q) / d + w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) if preserve_zeros: - W1_nz_mask = W_nz_mask[:, i1:i2] - - for i in range(count): - w = W1[:, i] - d = Hinv1[i, i] - q = w.clone() - - # quantize column - if strategy == QuantizationStrategy.TENSOR: - q = fake_quantize( - q, - scale, - zero_point, - quant_args, - ) - elif strategy == QuantizationStrategy.CHANNEL: - q = fake_quantize( - q, - scale[:, 0], - zero_point[:, 0], - quant_args, - ) - elif strategy == QuantizationStrategy.GROUP: - # get the group index for the current column - column_idx = i1 + i - group_index = g_idx[column_idx] - - # Since we're only applying quantization to a slice, this - # ends up being a channelwise application - altered_qargs = copy(quant_args) - altered_qargs.strategy = QuantizationStrategy.CHANNEL - q = fake_quantize( - q, - scale[:, group_index], - zero_point[:, group_index], - altered_qargs, - ) - else: - raise ValueError( - "Quantization strategy is not supported for GPTQ: " - f"{strategy}" - ) - - # propagate column error - Q1[:, i] = q - losses1[:, i] = (w - q) ** 2 / d**2 - - err1 = (w - q) / d - w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) - if preserve_zeros: - W1[:, i:] -= w1_err * W1_nz_mask[:, i:] - else: - W1[:, i:] -= w1_err - Err1[:, i] = err1 - - # propagate block error - W[:, i1:i2] = Q1 - losses += torch.sum(losses1, 1) / 2 - - w_err = Err1.matmul(Hinv[i1:i2, i2:]) - if preserve_zeros: - W[:, i2:] -= w_err * W_nz_mask[:, i2:] + W1[:, i:] -= w1_err * W1_nz_mask[:, i:] else: - W[:, i2:] -= w_err + W1[:, i:] -= w1_err + Err1[:, i] = err1 - has_gidx = False - if strategy == QuantizationStrategy.GROUP: - if actorder == ActivationOrdering.WEIGHT: - # restore original permutation - invperm = torch.argsort(perm) - W = W[:, invperm] + # propagate block error + W[:, i1:i2] = Q1 + losses += torch.sum(losses1, 1) / 2 - elif actorder == ActivationOrdering.GROUP: - # restore original permutation - invperm = torch.argsort(perm) - W = W[:, invperm] - g_idx = g_idx[invperm] + w_err = Err1.matmul(Hinv[i1:i2, i2:]) + if preserve_zeros: + W[:, i2:] -= w_err * W_nz_mask[:, i2:] + else: + W[:, i2:] -= w_err - # only save g_idx if mapping is not identity - has_gidx = True + has_gidx = False + if strategy == QuantizationStrategy.GROUP: + if actorder == ActivationOrdering.WEIGHT: + # restore original permutation + invperm = torch.argsort(perm) + W = W[:, invperm] - if not has_gidx: - g_idx = None + elif actorder == ActivationOrdering.GROUP: + # restore original permutation + invperm = torch.argsort(perm) + W = W[:, invperm] + g_idx = g_idx[invperm] - if module_class == transformers.Conv1D: - W.transpose_(0, 1) - W = W.reshape(final_shape).to(final_dtype) + # only save g_idx if mapping is not identity + has_gidx = True + + if not has_gidx: + g_idx = None + + if module_class == transformers.Conv1D: + W.transpose_(0, 1) + W = W.reshape(final_shape).to(final_dtype) - return losses, W, scale, zero_point, g_idx + return losses, W, scale, zero_point, g_idx def _apply_activation_ordering( W: torch.Tensor, H: torch.Tensor diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py index 413f5eaca..6ebb1dc7a 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py @@ -6,7 +6,7 @@ from llmcompressor.utils.metric_logging import get_GPU_memory_usage, get_layer_size_bytes -__all__ = ["get_output_error", "gptq_hook"] +__all__ = ["get_output_error", "gptq_hook", "MetricsLogger"] def get_output_error( @@ -65,7 +65,7 @@ def wrapped(self, *args, **kwargs): return wrapped -class LogMetrics: +class MetricsLogger: def __init__(self, module: torch.nn.Module): self.module = module self.start_tick = None @@ -74,8 +74,9 @@ def __init__(self, module: torch.nn.Module): def set_losses(self, losses: torch.Tensor): self.losses = losses - def __enter__(self): + def __enter__(self) -> "MetricsLogger": self.start_tick = time.time() + return self def __exit__(self, _exc_type, _exc_val, _exc_tb): """ From 7f49ab40c245bea5a8350b479856dd5ced9fb573 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Oct 2024 23:50:33 +0000 Subject: [PATCH 016/285] runnable --- src/llmcompressor/modifiers/quantization/gptq/base.py | 2 +- .../modifiers/quantization/gptq/utils/gptq_quantize.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 3d70b2d40..7b56f0e05 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -310,7 +310,7 @@ def layer_post_forward(self, name: str, module: torch.nn.Module, args: torch.Ten if not self.true_sequential: # rerun with (now) quantized weights with self.disable_hooks(): - output = module(args, **kwargs) + output = module(*args, **kwargs) self._layer_index += 1 return output diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 2f0d3120f..ebe657ae4 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -65,7 +65,8 @@ def quantize_weight( elif module_class == transformers.Conv1D: W.transpose_(0, 1) W = W.to(dtype=torch.float32) - num_columns = W.shape[0] + num_rows = W.shape[0] + num_columns = W.shape[1] if strategy == QuantizationStrategy.GROUP: # mapping from column index to group index @@ -103,7 +104,7 @@ def quantize_weight( else None ) - losses = torch.zeros(num_columns, device=weight.device) + losses = torch.zeros(num_rows, device=weight.device) # mask dead hessian values dead = torch.diag(H) == 0 @@ -118,7 +119,6 @@ def quantize_weight( for i1 in range(0, num_columns, blocksize): i2 = min(i1 + blocksize, num_columns) count = i2 - i1 - print((i1, i2, num_columns)) W1 = W[:, i1:i2].clone() Q1 = torch.zeros_like(W1) @@ -166,8 +166,7 @@ def quantize_weight( ) else: raise ValueError( - "Quantization strategy is not supported for GPTQ: " - f"{strategy}" + f"Quantization strategy is not supported for GPTQ: {strategy}" ) # propagate column error From 45390520b373774f6e2b2ac19065931192d826a9 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 18 Oct 2024 23:24:56 +0000 Subject: [PATCH 017/285] add ability to pass model class to support non-traditional (vision) models --- .../transformers/compression/helpers.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/llmcompressor/transformers/compression/helpers.py b/src/llmcompressor/transformers/compression/helpers.py index a505dc0af..0b9742bf7 100644 --- a/src/llmcompressor/transformers/compression/helpers.py +++ b/src/llmcompressor/transformers/compression/helpers.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Type, Union import psutil import torch @@ -172,6 +172,7 @@ def custom_offload_device_map( model_stub: str, max_memory_per_gpu: Union[str, int], num_gpus: int = 1, + model_cls: Type = AutoModelForCausalLM, **model_kwargs, ) -> Dict[Union[int, str], Union[int, str]]: """ @@ -182,6 +183,8 @@ def custom_offload_device_map( :param max_memory_per_gpu: Max memory to allocate on each GPU, as either a string such as "10GB" or an integer number of bytes :param num_gpus: number of gpus to utilize + :param model_cls: model class to use when initializing model structure, + default is AutoModelForCausalLM :param model_kwargs: keyword arguments to pass to model initializer :return: memory mapping for layers of model_stub to be passed to from_pretrained() """ @@ -190,7 +193,7 @@ def custom_offload_device_map( memory_limits["cpu"] = max_cpu_memory with init_empty_weights(): - dummy_model = AutoModelForCausalLM.from_pretrained(model_stub, **model_kwargs) + dummy_model = model_cls.from_pretrained(model_stub, **model_kwargs) device_map = infer_auto_device_map( dummy_model, max_memory=memory_limits, @@ -206,6 +209,7 @@ def calculate_offload_device_map( num_gpus: Optional[int] = None, gpu_ids: Optional[List[int]] = None, torch_dtype: torch.dtype = torch.float16, + model_cls: Type = AutoModelForCausalLM, **model_kwargs, ) -> Dict[Union[int, str], Union[int, str]]: """ @@ -217,6 +221,8 @@ def calculate_offload_device_map( :param num_gpus: number of gpus to utilize, defaults to max available :param gpu_ids: list of gpu device ids to utilize, overrides num_gpus if provided :param torch_dtype: datatype in which model weights are to be loaded with + :param model_cls: model class to use when initializing model structure, + default is AutoModelForCausalLM :param model_kwargs: keyword arguments to pass to model initializer :return: memory mapping for layers of model_stub to be passed to from_pretrained() """ @@ -235,12 +241,11 @@ def calculate_offload_device_map( ) max_gpu_memory = { - device_id: torch.cuda.mem_get_info(device_id)[0] - for device_id in gpu_ids + device_id: torch.cuda.mem_get_info(device_id)[0] for device_id in gpu_ids } with init_empty_weights(): - dummy_model = AutoModelForCausalLM.from_pretrained( + dummy_model = model_cls.from_pretrained( model_stub, torch_dtype=torch_dtype, **model_kwargs ) From 9ae998b6c54723018e69650c8f18fd5823de4be2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 18 Oct 2024 23:26:37 +0000 Subject: [PATCH 018/285] update docstring --- src/llmcompressor/transformers/compression/helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/transformers/compression/helpers.py b/src/llmcompressor/transformers/compression/helpers.py index 0b9742bf7..b21db930f 100644 --- a/src/llmcompressor/transformers/compression/helpers.py +++ b/src/llmcompressor/transformers/compression/helpers.py @@ -113,8 +113,8 @@ def infer_sparsity_structure_from_model(model: torch.nn.Module) -> Optional[str] def hessian_memory_requirements(model: torch.nn.Module) -> int: """ Determines the number of bytes needed to store Hessian data for a single - transformer layer in model. This is used for reserving memory for GPTQ/OBCQ - quantization + transformer layer in model. This is used for reserving memory for GPTQModifier + or SparseGPTModifier :param model: model to calculate requirements for :return: number of bytes required to reserve for GPTQ on a single layer From ac0d9266b1bb2af468dd8950646a0d1f1773ea41 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 21 Oct 2024 20:30:07 +0000 Subject: [PATCH 019/285] batching --- .../modifiers/quantization/gptq/base.py | 50 ++++++++----------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 7b56f0e05..1054dc436 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -20,6 +20,7 @@ from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import quantize_weight from llmcompressor.modifiers.quantization.gptq.utils.helpers import MetricsLogger from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier +from llmcompressor.transformers.finetune.data.data_helpers import format_calibration_data from llmcompressor.utils.fsdp.context import fix_fsdp_module_name from llmcompressor.utils.helpers import DisableKVCache, DisableQuantization, OnloadModule, getattr_chain from llmcompressor.utils.pytorch.module import ( @@ -257,37 +258,28 @@ def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Modu module._gptq_post_hook = module.register_forward_hook(post_hook, with_kwargs=True) def calibration_forward(self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader): - """ import torch.nn.functional as F - - accumulated_data = {} # Dictionary to accumulate samples per key - - def pad_tensor(tensor, max_len): - pad_size = max_len - tensor.size(1) # Calculate the padding size - return F.pad(tensor, (0, pad_size), value=0) # Pad on the right with zeros - - for batch in dataloader: - for key, value in batch.items(): - if key not in accumulated_data: - accumulated_data[key] = [] - accumulated_data[key].append(value) # Accumulate values for each key - - # Find maximum length for each key across all samples to ensure matching shapes - max_lengths = {} - for key, tensors in accumulated_data.items(): - max_lengths[key] = max([tensor.size(1) for tensor in tensors]) # Assuming the second dimension is the sequence length - - # Pad and concatenate for each key - concatenated_batch = { - key: torch.cat([pad_tensor(tensor, max_lengths[key]) for tensor in accumulated_data[key]], dim=0) - for key in accumulated_data - } - """ - - batch = next(iter(dataloader)) + from torch.nn.utils.rnn import pad_sequence + + dataset = dataloader.dataset + def collate_fn(batch): + # Extract input_ids and attention_mask from the batch + input_ids = [torch.tensor(item['input_ids']) for item in batch] + attention_masks = [torch.tensor(item['attention_mask']) for item in batch] + + # Pad sequences in the batch + padded_input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + padded_attention_masks = pad_sequence(attention_masks, batch_first=True, padding_value=0) + + return { + 'input_ids': padded_input_ids, + 'attention_mask': padded_attention_masks + } + dataloader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=True, collate_fn=collate_fn) + data = next(iter(dataloader)) with DisableKVCache(model), DisableQuantization(model): - model(**batch) + model(**data) @gptq_hook def target_pre_forward(self, name: str, module: torch.nn.Module, args): @@ -362,7 +354,7 @@ def remove_gptq_hooks(self, module: torch.nn.Module, recurse: bool = True): if recurse: for child_module in module.children(): - self.remove_hooks(child_module) + self.remove_gptq_hooks(child_module) def _build_quant_modifier(self): """ From 63049739f2e0fbb723dd9000e057a4000487c617 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 21 Oct 2024 22:12:06 +0000 Subject: [PATCH 020/285] calibration forward context --- .../modifiers/quantization/gptq/base.py | 36 +++++++------ .../modifiers/utils/layer_compressor.py | 17 +++++++ src/llmcompressor/utils/helpers.py | 51 ++++++++++++------- 3 files changed, 71 insertions(+), 33 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 1054dc436..342e194cd 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch +from torch.nn.utils.rnn import pad_sequence import contextlib from functools import partial from compressed_tensors.quantization import ( @@ -20,9 +21,10 @@ from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import quantize_weight from llmcompressor.modifiers.quantization.gptq.utils.helpers import MetricsLogger from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier +from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.transformers.finetune.data.data_helpers import format_calibration_data from llmcompressor.utils.fsdp.context import fix_fsdp_module_name -from llmcompressor.utils.helpers import DisableKVCache, DisableQuantization, OnloadModule, getattr_chain +from llmcompressor.utils.helpers import calibration_forward_context, align_module, getattr_chain from llmcompressor.utils.pytorch.module import ( get_layers, get_no_split_params, @@ -258,28 +260,32 @@ def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Modu module._gptq_post_hook = module.register_forward_hook(post_hook, with_kwargs=True) def calibration_forward(self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader): - import torch.nn.functional as F - from torch.nn.utils.rnn import pad_sequence - dataset = dataloader.dataset def collate_fn(batch): - # Extract input_ids and attention_mask from the batch - input_ids = [torch.tensor(item['input_ids']) for item in batch] - attention_masks = [torch.tensor(item['attention_mask']) for item in batch] + # extract input_ids and attention_mask from the batch + input_ids = [torch.tensor(item["input_ids"]) for item in batch] + attention_masks = [torch.tensor(item["attention_mask"]) for item in batch] - # Pad sequences in the batch + # pad sequences in the batch padded_input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) padded_attention_masks = pad_sequence(attention_masks, batch_first=True, padding_value=0) return { - 'input_ids': padded_input_ids, - 'attention_mask': padded_attention_masks + "input_ids": padded_input_ids, + "attention_mask": padded_attention_masks } - dataloader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=True, collate_fn=collate_fn) - data = next(iter(dataloader)) - with DisableKVCache(model), DisableQuantization(model): - model(**data) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=len(dataset), + shuffle=True, + collate_fn=collate_fn, + pin_memory=True + ) + + calibration_data = next(iter(dataloader)) + with calibration_forward_context(model): + model(**calibration_data) @gptq_hook def target_pre_forward(self, name: str, module: torch.nn.Module, args): @@ -314,7 +320,7 @@ def quantize_module(self, name, module, args): quant_args = getattr_chain(module, "quantization_scheme.weights") # with onloaded weight - with OnloadModule(module), MetricsLogger(module) as metrics_logger: + with align_module(module), MetricsLogger(module) as metrics_logger: losses, quantized_weight, scale, zero_point, g_idx = quantize_weight( module.weight.data, inp, diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index 3dd3caa7e..714d328df 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -20,6 +20,23 @@ __all__ = ["LayerCompressor"] +class LayerCompressorMixin: + def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Module]): + return + for name, module in model.named_modules(): + if getattr_chain(module, "quantization_scheme.weights", None) is not None: + pre_hook = partial(self.target_pre_forward, name) + post_hook = partial(self.target_post_forward, name) + module._gptq_pre_hook = module.register_forward_pre_hook(pre_hook) + module._gptq_post_hook = module.register_forward_hook(post_hook) + + if module in layers.values(): + pre_hook = partial(self.layer_pre_forward, name) + post_hook = partial(self.layer_post_forward, name) + module._gptq_pre_hook = module.register_forward_pre_hook(pre_hook) + module._gptq_post_hook = module.register_forward_hook(post_hook, with_kwargs=True) + + class LayerCompressor: """ Runs weight sparisification on a single layer using calibration data inputs. The diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index b46685110..db4846b7b 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -15,10 +15,11 @@ import sys import tarfile import warnings +import contextlib from collections import OrderedDict from io import BytesIO from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Tuple, Union, Optional from urllib.parse import urlparse import numpy @@ -1078,26 +1079,40 @@ def __exit__(self, _exc_type, _exc_val, _exc_tb): self.config.use_cache = self.restore_value -class DisableQuantization: - def __init__(self, model: torch.nn.Module): - self.model = model +@contextlib.contextmanager +def DisableQuantization(model: torch.nn.Module): + model.apply(disable_quantization) + yield + model.apply(enable_quantization) - def __enter__(self): - self.model.apply(disable_quantization) - def __exit__(self, _exc_type, _exc_val, _exc_tb): - self.model.apply(enable_quantization) +def calibration_forward_context(model: torch.nn.Module): + torch.eval() + with ( + torch.no_grad(), + DisableKVCache(model), + DisableQuantization(model), + ): + yield -class OnloadModule: - def __init__(self, module: torch.nn.Module): - self.module = module - self.is_module_offloaded = is_module_offloaded(self.module) - def __enter__(self): - if self.is_module_offloaded: - self.module._hf_hook.pre_forward(self.module) +@contextlib.contextmanager +def align_module(module: torch.nn.Module, device: Optional[torch.device] = None): + """ + Move an offloaded module's parameters to device or module execution device - def __exit__(self, _exc_type, _exc_val, _exc_tb): - if self.is_module_offloaded: - self.module._hf_hook.post_forward(self.module, None) \ No newline at end of file + :param module: module with parameters to align + :param device: optional device to move parameters to, if None is provided then + module execution device will be used + """ + if device is not None: + original_device = module._hf_hook.execution_device + module._hf_hook.execution_device = device + + module._hf_hook.pre_forward(module) + yield + module._hf_hook.post_forward(module, torch.tensor([])) + + if device is not None: + module._hf_hook.execution_device = original_device \ No newline at end of file From 868a480d9c3ae076dec8861bbcb03bc03b6b799b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 21 Oct 2024 22:31:08 +0000 Subject: [PATCH 021/285] fix stuff --- examples/quantization_w4a16/llama3_example.py | 5 +++-- src/llmcompressor/modifiers/quantization/gptq/base.py | 3 +-- src/llmcompressor/modifiers/utils/pytorch_helpers.py | 2 +- src/llmcompressor/utils/helpers.py | 3 ++- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index 01d9dba8c..56aef6b7a 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -23,7 +23,7 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 512 +NUM_CALIBRATION_SAMPLES = 512 // 4 MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. @@ -44,10 +44,11 @@ def preprocess(example): # Tokenize inputs. +tokenizer.add_special_tokens({'pad_token': '[PAD]'}) def tokenize(sample): return tokenizer( sample["text"], - padding=False, + padding=True, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False, diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 342e194cd..77cf3c605 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -283,9 +283,8 @@ def collate_fn(batch): pin_memory=True ) - calibration_data = next(iter(dataloader)) with calibration_forward_context(model): - model(**calibration_data) + run_calibration_forward(model, dataloader, mask_padding=True) @gptq_hook def target_pre_forward(self, name: str, module: torch.nn.Module, args): diff --git a/src/llmcompressor/modifiers/utils/pytorch_helpers.py b/src/llmcompressor/modifiers/utils/pytorch_helpers.py index 9003ff22d..20abaf376 100644 --- a/src/llmcompressor/modifiers/utils/pytorch_helpers.py +++ b/src/llmcompressor/modifiers/utils/pytorch_helpers.py @@ -102,7 +102,7 @@ def run_calibration_forward( # TODO: not ideal, figure out where we aren't freeing memory instead # currently without this we run OOM on the 2nd forward pass - torch.cuda.empty_cache() + #torch.cuda.empty_cache() return intermediates diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index db4846b7b..14c724320 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -1086,8 +1086,9 @@ def DisableQuantization(model: torch.nn.Module): model.apply(enable_quantization) +@contextlib.contextmanager def calibration_forward_context(model: torch.nn.Module): - torch.eval() + model.eval() with ( torch.no_grad(), From 86c8a06dae722f289360737600d67714359ac797 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 21 Oct 2024 22:44:45 +0000 Subject: [PATCH 022/285] wip --- examples/quantization_w4a16/llama3_example.py | 4 ++-- src/llmcompressor/modifiers/quantization/gptq/base.py | 1 + .../modifiers/quantization/gptq/utils/gptq_quantize.py | 9 +++++++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index 56aef6b7a..fbb1f2e2c 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -24,7 +24,7 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. NUM_CALIBRATION_SAMPLES = 512 // 4 -MAX_SEQUENCE_LENGTH = 2048 +MAX_SEQUENCE_LENGTH = 2048 // 2 # Load dataset and preprocess. ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) @@ -59,7 +59,7 @@ def tokenize(sample): # Configure the quantization algorithm to run. # * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]) +recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"], percdamp=0.01) # Apply algorithms. oneshot( diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 77cf3c605..5cfd036a0 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -283,6 +283,7 @@ def collate_fn(batch): pin_memory=True ) + breakpoint() with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index ebe657ae4..203dac5f0 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -12,19 +12,24 @@ from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD +GPTQ_PRECISION = torch.float32 + + def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: inp = inp.to(device=device) if len(inp.shape) == 2: inp = inp.unsqueeze(0) + breakpoint() if module_class in (torch.nn.Linear, transformers.Conv1D): if len(inp.shape) == 3: inp = inp.reshape((-1, inp.shape[-1])) inp = inp.t() nsamples = inp.shape[0] + breakpoint() - inp = inp.to(dtype=torch.float32) + inp = inp.to(dtype=GPTQ_PRECISION) inp = math.sqrt(2 / nsamples) * inp return inp.matmul(inp.t()) @@ -64,7 +69,7 @@ def quantize_weight( W = W.flatten(1) elif module_class == transformers.Conv1D: W.transpose_(0, 1) - W = W.to(dtype=torch.float32) + W = W.to(dtype=GPTQ_PRECISION) num_rows = W.shape[0] num_columns = W.shape[1] From 130517354f465253d1669a7ac70b2dbb9c85a905 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 21 Oct 2024 23:01:07 +0000 Subject: [PATCH 023/285] use hooks list --- examples/quantization_w4a16/llama3_example.py | 2 +- .../modifiers/quantization/gptq/base.py | 28 +++++++------------ .../quantization/gptq/utils/gptq_quantize.py | 7 ++--- 3 files changed, 14 insertions(+), 23 deletions(-) diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index fbb1f2e2c..2568c59ed 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -23,7 +23,7 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 512 // 4 +NUM_CALIBRATION_SAMPLES = 512 // 6 MAX_SEQUENCE_LENGTH = 2048 // 2 # Load dataset and preprocess. diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 5cfd036a0..216d9e6cb 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -3,6 +3,7 @@ import torch from torch.nn.utils.rnn import pad_sequence +from torch.utils.hooks import RemovableHandle import contextlib from functools import partial from compressed_tensors.quantization import ( @@ -125,6 +126,7 @@ class GPTQModifier(Modifier): _num_layers: int = 0 _hooks_disabled: bool = False quantization_modifier_: Optional[QuantizationModifier] = None + _hooks: List[RemovableHandle] = [] @field_validator("sequential_update", mode="before") def validate_sequential_update(cls, value: bool) -> bool: @@ -241,7 +243,7 @@ def on_finalize(self, state: "State", **kwargs) -> bool: if self.quantization_modifier_: self.quantization_modifier_.finalize(state, **kwargs) - self.remove_gptq_hooks(state.model) + self.remove_gptq_hooks() return True @@ -250,14 +252,14 @@ def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Modu if getattr_chain(module, "quantization_scheme.weights", None) is not None: pre_hook = partial(self.target_pre_forward, name) post_hook = partial(self.target_post_forward, name) - module._gptq_pre_hook = module.register_forward_pre_hook(pre_hook) - module._gptq_post_hook = module.register_forward_hook(post_hook) + self._hooks.append(module.register_forward_pre_hook(pre_hook)) + self._hooks.append(module.register_forward_hook(post_hook)) if module in layers.values(): pre_hook = partial(self.layer_pre_forward, name) post_hook = partial(self.layer_post_forward, name) - module._gptq_pre_hook = module.register_forward_pre_hook(pre_hook) - module._gptq_post_hook = module.register_forward_hook(post_hook, with_kwargs=True) + self._hooks.append(module.register_forward_pre_hook(pre_hook)) + self._hooks.append(module.register_forward_hook(post_hook, with_kwargs=True)) def calibration_forward(self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader): dataset = dataloader.dataset @@ -283,7 +285,6 @@ def collate_fn(batch): pin_memory=True ) - breakpoint() with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) @@ -349,18 +350,9 @@ def disable_hooks(self): finally: self._hooks_disabled = False - def remove_gptq_hooks(self, module: torch.nn.Module, recurse: bool = True): - if hasattr(module, "_gptq_pre_hook"): - module._gptq_pre_hook.remove() - delattr(module, "_gptq_pre_hook") - - if hasattr(module, "_gptq_post_hook"): - module._gptq_post_hook.remove() - delattr(module, "_gptq_post_hook") - - if recurse: - for child_module in module.children(): - self.remove_gptq_hooks(child_module) + def remove_gptq_hooks(self): + for hook in self._hooks: + hook.remove() def _build_quant_modifier(self): """ diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 203dac5f0..8e87f3ee0 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -20,15 +20,14 @@ def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: if len(inp.shape) == 2: inp = inp.unsqueeze(0) - breakpoint() + nsamples = inp.shape[0] # note this is the number of dataset samples, not + # multiplied by the sequence length + if module_class in (torch.nn.Linear, transformers.Conv1D): if len(inp.shape) == 3: inp = inp.reshape((-1, inp.shape[-1])) inp = inp.t() - nsamples = inp.shape[0] - breakpoint() - inp = inp.to(dtype=GPTQ_PRECISION) inp = math.sqrt(2 / nsamples) * inp return inp.matmul(inp.t()) From e6adc5a9a823cd5dcc615ad210a2c639bf09f7ce Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 18:56:01 +0000 Subject: [PATCH 024/285] layer compressor --- .../modifiers/quantization/gptq/base.py | 147 ++++++------------ .../quantization/gptq/utils/gptq_quantize.py | 32 ++-- .../quantization/gptq/utils/helpers.py | 9 +- .../modifiers/utils/layer_compressor.py | 124 +++++++++++++-- .../modifiers/utils/pytorch_helpers.py | 2 +- src/llmcompressor/utils/helpers.py | 13 +- 6 files changed, 187 insertions(+), 140 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 216d9e6cb..1b1e56f23 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -1,43 +1,35 @@ -import gc -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union import torch -from torch.nn.utils.rnn import pad_sequence -from torch.utils.hooks import RemovableHandle -import contextlib -from functools import partial from compressed_tensors.quantization import ( QuantizationScheme, freeze_module_quantization, ) +from compressed_tensors.utils import ( + is_module_offloaded, + update_parameter_data, + update_prefix_dict, +) from loguru import logger from pydantic import Field, field_validator +from torch.nn.utils.rnn import pad_sequence +from torch.utils.hooks import RemovableHandle from llmcompressor.core import State from llmcompressor.modifiers import Modifier, ModifierFactory -from llmcompressor.modifiers.quantization.gptq.utils import ( - get_output_error, - gptq_hook +from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import ( + quantize_weight, ) -from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import quantize_weight from llmcompressor.modifiers.quantization.gptq.utils.helpers import MetricsLogger from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier +from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward -from llmcompressor.transformers.finetune.data.data_helpers import format_calibration_data -from llmcompressor.utils.fsdp.context import fix_fsdp_module_name -from llmcompressor.utils.helpers import calibration_forward_context, align_module, getattr_chain -from llmcompressor.utils.pytorch.module import ( - get_layers, - get_no_split_params, - qat_active, -) - -from compressed_tensors.utils import ( - get_offloaded_device, - is_module_offloaded, - update_parameter_data, - update_prefix_dict, +from llmcompressor.utils.helpers import ( + align_module, + calibration_forward_context, + getattr_chain, ) +from llmcompressor.utils.pytorch.module import qat_active __all__ = ["GPTQModifier"] @@ -138,13 +130,16 @@ def validate_sequential_update(cls, value: bool) -> bool: ) return value - + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._layer_index = 0 self._num_layers = 0 self.quantization_modifier_ = None + self.layer_compressor = LayerCompressor( + self.quantize_module, self.true_sequential + ) def on_initialize_structure(self, state: State, **kwargs): """ @@ -212,18 +207,9 @@ def on_initialize(self, state: "State", **kwargs) -> bool: if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") - # find layers (used for printing even if true_sequential=True) - # if no targets are provided, default to the modules that shouldn't be - # split by FSDP. For Transformers models this is equivalent to the - # decoder layers (ie LlamaDecoderLayer) - if self.sequential_targets is None: - self.sequential_targets = get_no_split_params(state.model) - layers = get_layers(self.sequential_targets, state.model) - self._num_layers = len(layers) - # add hooks to targets and layers # after lifecycle refactor, move this to pre_batch - self.register_hooks(state.model, layers) + self.layer_compressor.register_hooks(state.model, self.sequential_targets) # apply calibration and trigger hooks (hooks are self removing) self.calibration_forward(state.model, state.data.calib) @@ -233,7 +219,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: state.model.apply(freeze_module_quantization) return True - + def on_finalize(self, state: "State", **kwargs) -> bool: """ disable the quantization observers used by the OBCQ algorithm @@ -243,81 +229,50 @@ def on_finalize(self, state: "State", **kwargs) -> bool: if self.quantization_modifier_: self.quantization_modifier_.finalize(state, **kwargs) - self.remove_gptq_hooks() + self.layer_compressor.remove_hooks() return True - - def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Module]): - for name, module in model.named_modules(): - if getattr_chain(module, "quantization_scheme.weights", None) is not None: - pre_hook = partial(self.target_pre_forward, name) - post_hook = partial(self.target_post_forward, name) - self._hooks.append(module.register_forward_pre_hook(pre_hook)) - self._hooks.append(module.register_forward_hook(post_hook)) - - if module in layers.values(): - pre_hook = partial(self.layer_pre_forward, name) - post_hook = partial(self.layer_post_forward, name) - self._hooks.append(module.register_forward_pre_hook(pre_hook)) - self._hooks.append(module.register_forward_hook(post_hook, with_kwargs=True)) - - def calibration_forward(self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader): + + def calibration_forward( + self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader + ): dataset = dataloader.dataset + def collate_fn(batch): # extract input_ids and attention_mask from the batch input_ids = [torch.tensor(item["input_ids"]) for item in batch] attention_masks = [torch.tensor(item["attention_mask"]) for item in batch] - + # pad sequences in the batch - padded_input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) - padded_attention_masks = pad_sequence(attention_masks, batch_first=True, padding_value=0) + padded_input_ids = pad_sequence( + input_ids, batch_first=True, padding_value=0 + ) + padded_attention_masks = pad_sequence( + attention_masks, batch_first=True, padding_value=0 + ) return { "input_ids": padded_input_ids, - "attention_mask": padded_attention_masks + "attention_mask": padded_attention_masks, } - + dataloader = torch.utils.data.DataLoader( dataset, batch_size=len(dataset), shuffle=True, collate_fn=collate_fn, - pin_memory=True + pin_memory=True, ) - + with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) - @gptq_hook - def target_pre_forward(self, name: str, module: torch.nn.Module, args): - if self.true_sequential: - # compress first so output is from quantized weights - self.quantize_module(name, module, args) - - @gptq_hook - def target_post_forward(self, name: str, module: torch.nn.Module, args: torch.Tensor, _output: Any): - if not self.true_sequential: - # compress after so output is from unquantized weights - self.quantize_module(name, module, args) - - @gptq_hook - def layer_pre_forward(self, name: str, module: torch.nn.Module, args: Any): - logger.info(f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====") - - @gptq_hook - def layer_post_forward(self, name: str, module: torch.nn.Module, args: torch.Tensor, kwargs: Dict[str, Any], output: Any): - if not self.true_sequential: - # rerun with (now) quantized weights - with self.disable_hooks(): - output = module(*args, **kwargs) - - self._layer_index += 1 - return output - def quantize_module(self, name, module, args): logger.info(f"Compressing {name}...") - inp = args[0] # Assume that first argument is input (true for most Module types) + inp = args[ + 0 + ] # Assume that first argument is input (true for most Module types) quant_args = getattr_chain(module, "quantization_scheme.weights") # with onloaded weight @@ -330,10 +285,10 @@ def quantize_module(self, name, module, args): percdamp=self.dampening_frac, module_class=type(module), ) - - #weight = torch.lerp(module.weight.data, quantized_weight, self.alpha) + + # weight = torch.lerp(module.weight.data, quantized_weight, self.alpha) weight = quantized_weight - + if is_module_offloaded(module): update_prefix_dict(self.layer, "weight", weight) update_parameter_data(module, scale, "weight_scale") @@ -341,18 +296,6 @@ def quantize_module(self, name, module, args): update_parameter_data(module, g_idx, "weight_g_idx") metrics_logger.set_losses(losses) - - @contextlib.contextmanager - def disable_hooks(self): - try: - self._hooks_disabled = True - yield - finally: - self._hooks_disabled = False - - def remove_gptq_hooks(self): - for hook in self._hooks: - hook.remove() def _build_quant_modifier(self): """ diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 8e87f3ee0..a94a8bf69 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -1,16 +1,19 @@ +import math +from copy import copy from typing import Tuple, Union -import time -import math import torch import transformers -from copy import copy -from loguru import logger -from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy, ActivationOrdering, fake_quantize +from compressed_tensors.quantization import ( + ActivationOrdering, + QuantizationArgs, + QuantizationStrategy, + fake_quantize, +) from compressed_tensors.quantization.observers import MovingAverageMinMaxObserver -from llmcompressor.pytorch.utils.helpers import tensor_sparsity -from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD +from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD +from llmcompressor.pytorch.utils.helpers import tensor_sparsity GPTQ_PRECISION = torch.float32 @@ -21,7 +24,7 @@ def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: inp = inp.unsqueeze(0) nsamples = inp.shape[0] # note this is the number of dataset samples, not - # multiplied by the sequence length + # multiplied by the sequence length if module_class in (torch.nn.Linear, transformers.Conv1D): if len(inp.shape) == 3: @@ -43,7 +46,9 @@ def invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: return H -def compute_scale_zeropoint(W: torch.Tensor, quant_args: QuantizationArgs) -> Tuple[torch.Tensor, torch.Tensor]: +def compute_scale_zeropoint( + W: torch.Tensor, quant_args: QuantizationArgs +) -> Tuple[torch.Tensor, torch.Tensor]: return MovingAverageMinMaxObserver(quant_args)(W) @@ -53,14 +58,16 @@ def quantize_weight( quant_args: QuantizationArgs, blocksize: int = 128, percdamp: float = 0.01, - module_class = torch.nn.Linear, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]: + module_class=torch.nn.Linear, +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor +]: strategy = quant_args.strategy actorder = quant_args.actorder final_shape = weight.shape final_dtype = weight.dtype W = weight.data.clone() - + H = compute_hessian(inp, module_class, device=weight.device) # standardize shape and dtype @@ -220,6 +227,7 @@ def quantize_weight( return losses, W, scale, zero_point, g_idx + def _apply_activation_ordering( W: torch.Tensor, H: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py index 6ebb1dc7a..fceb7fd75 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py @@ -1,10 +1,13 @@ +import time from typing import Any, Iterable, List, Tuple, Union -import time import torch from loguru import logger -from llmcompressor.utils.metric_logging import get_GPU_memory_usage, get_layer_size_bytes +from llmcompressor.utils.metric_logging import ( + get_GPU_memory_usage, + get_layer_size_bytes, +) __all__ = ["get_output_error", "gptq_hook", "MetricsLogger"] @@ -59,7 +62,7 @@ def gptq_hook(func): def wrapped(self, *args, **kwargs): if self._hooks_disabled: return - + func(self, *args, **kwargs) return wrapped diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index 714d328df..b6fe73ce6 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -1,12 +1,14 @@ +import contextlib import operator -from typing import Dict, Tuple +from functools import partial +from typing import Any, Callable, Dict, List, Tuple, Union import torch from compressed_tensors import get_execution_device from loguru import logger -from torch.nn import Module from tqdm import tqdm +from llmcompressor.modifiers.quantization.gptq.utils.helpers import get_output_error from llmcompressor.modifiers.utils.compression_wrapper import ModuleCompressionWrapper from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException from llmcompressor.pytorch.utils import tensors_to_device @@ -14,27 +16,123 @@ fix_fsdp_module_name, summon_full_params_context, ) -from llmcompressor.utils.pytorch import set_layer -from llmcompressor.utils.pytorch.module import get_prunable_layers +from llmcompressor.utils.helpers import getattr_chain +from llmcompressor.utils.pytorch.module import ( + get_layers, + get_no_split_params, + get_prunable_layers, + set_layer, +) __all__ = ["LayerCompressor"] -class LayerCompressorMixin: - def register_hooks(self, model: torch.nn.Module, layers: Dict[str, torch.nn.Module]): - return +class HooksMixin: + def __init__(self): + self.hooks = [] + self.hooks_disabled = False + + @classmethod + def hook(func): + def wrapped(self, *args, **kwargs): + if self.hooks_disabled: + return + + func(self, *args, **kwargs) + + return wrapped + + @contextlib.contextmanager + def disable_hooks(self): + try: + self._hooks_disabled = True + yield + finally: + self._hooks_disabled = False + + def remove_hooks(self): + for hook in self.hooks: + hook.remove() + + +class SequentialLayerCompressor(HooksMixin): + def __init__( + self, + compress_fn: Callable[[str, torch.nn.Module, torch.Tensor], Any], + true_sequential: bool = True, + ): + self.compress_fn = compress_fn + self.true_sequential = true_sequential + + self._layer_index = 0 + self._num_layers = 0 + + def register_hooks( + self, model: torch.nn.Module, sequential_targets: Union[str, List[str], None] + ): + # find layers (used for printing even if true_sequential=True) + # if no targets are provided, default to the modules that shouldn't be + # split by FSDP. For Transformers models this is equivalent to the + # decoder layers (ie LlamaDecoderLayer) + if self.sequential_targets is None: + self.sequential_targets = get_no_split_params(model) + layers = get_layers(sequential_targets, model) + self._num_layers = len(layers) + for name, module in model.named_modules(): if getattr_chain(module, "quantization_scheme.weights", None) is not None: pre_hook = partial(self.target_pre_forward, name) post_hook = partial(self.target_post_forward, name) - module._gptq_pre_hook = module.register_forward_pre_hook(pre_hook) - module._gptq_post_hook = module.register_forward_hook(post_hook) + self._hooks.append(module.register_forward_pre_hook(pre_hook)) + self._hooks.append(module.register_forward_hook(post_hook)) if module in layers.values(): pre_hook = partial(self.layer_pre_forward, name) post_hook = partial(self.layer_post_forward, name) - module._gptq_pre_hook = module.register_forward_pre_hook(pre_hook) - module._gptq_post_hook = module.register_forward_hook(post_hook, with_kwargs=True) + self._hooks.append(module.register_forward_pre_hook(pre_hook)) + self._hooks.append( + module.register_forward_hook(post_hook, with_kwargs=True) + ) + + @HooksMixin.hook + def target_pre_forward(self, name: str, module: torch.nn.Module, args): + if self.true_sequential: + # compress first so output is from quantized weights + self.compress_fn(name, module, args) + + @HooksMixin.hook + def target_post_forward( + self, name: str, module: torch.nn.Module, args: torch.Tensor, _output: Any + ): + if not self.true_sequential: + # compress after so output is from unquantized weights + self.compress_fn(name, module, args) + + @HooksMixin.hook + def layer_pre_forward(self, name: str, module: torch.nn.Module, args: Any): + logger.info( + f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====" + ) + + @HooksMixin.hook + def layer_post_forward( + self, + name: str, + module: torch.nn.Module, + args: torch.Tensor, + kwargs: Dict[str, Any], + output: Any, + ): + if not self.true_sequential: + # rerun with (now) compressed weights + with self.disable_hooks(): + compressed_output = module(*args, **kwargs) + + error = get_output_error(output, compressed_output) + logger.info(f"Mean output error from quantization: {error:.3f}") + + self._layer_index += 1 + return output class LayerCompressor: @@ -62,8 +160,8 @@ class LayerCompressor: def __init__( self, module_compressor_class: ModuleCompressionWrapper, - model: Module, - layer: Module, + model: torch.nn.Module, + layer: torch.nn.Module, layer_index: int, name: str, args: Dict, diff --git a/src/llmcompressor/modifiers/utils/pytorch_helpers.py b/src/llmcompressor/modifiers/utils/pytorch_helpers.py index 20abaf376..c2f52a1cf 100644 --- a/src/llmcompressor/modifiers/utils/pytorch_helpers.py +++ b/src/llmcompressor/modifiers/utils/pytorch_helpers.py @@ -102,7 +102,7 @@ def run_calibration_forward( # TODO: not ideal, figure out where we aren't freeing memory instead # currently without this we run OOM on the 2nd forward pass - #torch.cuda.empty_cache() + # torch.cuda.empty_cache() return intermediates diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index 14c724320..5891ab182 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -4,6 +4,7 @@ """ import ast +import contextlib import errno import fnmatch import glob @@ -15,23 +16,17 @@ import sys import tarfile import warnings -import contextlib from collections import OrderedDict from io import BytesIO from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Tuple, Union, Optional +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union from urllib.parse import urlparse import numpy import torch +from compressed_tensors.quantization import disable_quantization, enable_quantization from loguru import logger -from compressed_tensors.quantization import ( - disable_quantization, - enable_quantization, -) -from compressed_tensors import is_module_offloaded - __all__ = [ "ALL_TOKEN", "ALL_PRUNABLE_TOKEN", @@ -1116,4 +1111,4 @@ def align_module(module: torch.nn.Module, device: Optional[torch.device] = None) module._hf_hook.post_forward(module, torch.tensor([])) if device is not None: - module._hf_hook.execution_device = original_device \ No newline at end of file + module._hf_hook.execution_device = original_device From f65f8322633ec79de04f0cbad4e6ea763f21751e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 19:12:48 +0000 Subject: [PATCH 025/285] style --- .../modifiers/quantization/gptq/base.py | 5 ++-- src/llmcompressor/utils/helpers.py | 29 ++++++++++++++----- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 1b1e56f23..7e0d968dd 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -270,9 +270,8 @@ def collate_fn(batch): def quantize_module(self, name, module, args): logger.info(f"Compressing {name}...") - inp = args[ - 0 - ] # Assume that first argument is input (true for most Module types) + # Assume that first argument is input (true for most supported Module types) + inp = args[0] quant_args = getattr_chain(module, "quantization_scheme.weights") # with onloaded weight diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index 5891ab182..03abf18be 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -24,6 +24,7 @@ import numpy import torch +from compressed_tensors import is_module_offloaded from compressed_tensors.quantization import disable_quantization, enable_quantization from loguru import logger @@ -1102,13 +1103,25 @@ def align_module(module: torch.nn.Module, device: Optional[torch.device] = None) :param device: optional device to move parameters to, if None is provided then module execution device will be used """ - if device is not None: - original_device = module._hf_hook.execution_device - module._hf_hook.execution_device = device + if is_module_offloaded(module): + if device is not None: + original_device = module._hf_hook.execution_device + module._hf_hook.execution_device = device - module._hf_hook.pre_forward(module) - yield - module._hf_hook.post_forward(module, torch.tensor([])) + module._hf_hook.pre_forward(module) + yield + module._hf_hook.post_forward(module, torch.tensor([])) + + if device is not None: + module._hf_hook.execution_device = original_device + + elif device is not None: + devices = {} + for name, param in module.named_parameters(): + devices[name] = param.device + setattr(module, name, param.to(device)) + + yield - if device is not None: - module._hf_hook.execution_device = original_device + for name, param_device in module.named_parameters: + setattr(module, name, param.to(param_device)) From 1e225692d9ccb7a3769d7e1dc724770f69cb7d92 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 19:55:15 +0000 Subject: [PATCH 026/285] use layer compressor --- examples/quantization_w4a16/llama3_example.py | 12 +++-- .../modifiers/quantization/gptq/base.py | 38 ++++++-------- .../quantization/gptq/utils/helpers.py | 52 +++++-------------- .../modifiers/utils/layer_compressor.py | 48 +++++++++-------- src/llmcompressor/utils/helpers.py | 3 ++ .../pruning/sparsegpt/test_pytorch.py | 16 +++--- 6 files changed, 73 insertions(+), 96 deletions(-) diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index 2568c59ed..96adcbfdc 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -6,8 +6,8 @@ from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot # Select model and load it. -#MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" -#MODEL_ID = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" +# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" +# MODEL_ID = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" model = SparseAutoModelForCausalLM.from_pretrained( @@ -44,7 +44,9 @@ def preprocess(example): # Tokenize inputs. -tokenizer.add_special_tokens({'pad_token': '[PAD]'}) +tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + def tokenize(sample): return tokenizer( sample["text"], @@ -59,7 +61,9 @@ def tokenize(sample): # Configure the quantization algorithm to run. # * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"], percdamp=0.01) +recipe = GPTQModifier( + targets="Linear", scheme="W4A16", ignore=["lm_head"], percdamp=0.01 +) # Apply algorithms. oneshot( diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 7e0d968dd..44dcdf194 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -13,7 +13,6 @@ from loguru import logger from pydantic import Field, field_validator from torch.nn.utils.rnn import pad_sequence -from torch.utils.hooks import RemovableHandle from llmcompressor.core import State from llmcompressor.modifiers import Modifier, ModifierFactory @@ -22,7 +21,7 @@ ) from llmcompressor.modifiers.quantization.gptq.utils.helpers import MetricsLogger from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier -from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor +from llmcompressor.modifiers.utils.layer_compressor import SequentialLayerCompressor from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.utils.helpers import ( align_module, @@ -72,6 +71,7 @@ class GPTQModifier(Modifier): :param sequential_update: Whether or not to update weights sequentially by layer, True saves on GPU memory, default is True + :param true_sequential: TODO :param targets: list of layer names to compress during GPTQ, or '__ALL__' to compress every layer in the model :param block_size: Used to determine number of columns to compress in one pass @@ -102,7 +102,7 @@ class GPTQModifier(Modifier): """ sequential_update: bool = True - true_sequential: bool = False + true_sequential: bool = True targets: Union[str, List[str], None] = None sequential_targets: Union[str, List[str], None] = None block_size: int = 128 @@ -114,11 +114,8 @@ class GPTQModifier(Modifier): num_calibration_steps: Optional[int] = None scheme: Optional[Union[str, Dict[str, Any]]] = None - _layer_index: int = 0 - _num_layers: int = 0 - _hooks_disabled: bool = False - quantization_modifier_: Optional[QuantizationModifier] = None - _hooks: List[RemovableHandle] = [] + _quantization_modifier: Optional[QuantizationModifier] = None + _layer_compressor: Optional[SequentialLayerCompressor] = None @field_validator("sequential_update", mode="before") def validate_sequential_update(cls, value: bool) -> bool: @@ -134,10 +131,7 @@ def validate_sequential_update(cls, value: bool) -> bool: def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._layer_index = 0 - self._num_layers = 0 - self.quantization_modifier_ = None - self.layer_compressor = LayerCompressor( + self._layer_compressor = SequentialLayerCompressor( self.quantize_module, self.true_sequential ) @@ -191,8 +185,8 @@ def on_initialize_structure(self, state: State, **kwargs): self._build_quant_modifier_from_dict(self.quantize) self.quantize = True - if self.quantization_modifier_: - self.quantization_modifier_.on_initialize_structure(state, **kwargs) + if self._quantization_modifier: + self._quantization_modifier.on_initialize_structure(state, **kwargs) def on_initialize(self, state: "State", **kwargs) -> bool: """ @@ -202,14 +196,14 @@ def on_initialize(self, state: "State", **kwargs) -> bool: """ if not self.initialized_structure_: self.on_initialize_structure(state, **kwargs) - if self.quantization_modifier_: - self.quantization_modifier_.initialize(state, **kwargs) + if self._quantization_modifier: + self._quantization_modifier.initialize(state, **kwargs) if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") # add hooks to targets and layers # after lifecycle refactor, move this to pre_batch - self.layer_compressor.register_hooks(state.model, self.sequential_targets) + self._layer_compressor.register_hooks(state.model, self.sequential_targets) # apply calibration and trigger hooks (hooks are self removing) self.calibration_forward(state.model, state.data.calib) @@ -226,10 +220,10 @@ def on_finalize(self, state: "State", **kwargs) -> bool: :param state: session state storing input model and calibration data """ - if self.quantization_modifier_: - self.quantization_modifier_.finalize(state, **kwargs) + if self._quantization_modifier: + self._quantization_modifier.finalize(state, **kwargs) - self.layer_compressor.remove_hooks() + self._layer_compressor.remove_hooks() return True @@ -301,7 +295,7 @@ def _build_quant_modifier(self): Build a quantization modifier based on the specified config_groups, ignore list, and num_calibration_steps. - :postcondition: self.quantization_modifier_ is set to the built + :postcondition: self._quantization_modifier is set to the built quantization modifier """ @@ -327,7 +321,7 @@ def _build_quant_modifier(self): def _build_quant_modifier_from_dict(self, quant_config): modifier_type = list(quant_config.keys())[0] modifier_args = quant_config[modifier_type] - self.quantization_modifier_ = ModifierFactory.create( + self._quantization_modifier = ModifierFactory.create( modifier_type, allow_registered=True, allow_experimental=True, diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py index fceb7fd75..a369e0d4c 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py @@ -13,51 +13,23 @@ def get_output_error( - unquantized: List[Tuple[Union[Iterable, torch.Tensor], Any]], - quantized: List[Tuple[Union[Iterable, torch.Tensor], Any]], + uncompressed: Tuple[torch.Tensor, ...], + compressed: Tuple[torch.Tensor, ...], ) -> torch.Tensor: """ - Calculate mean l1 loss between weight-unquantized outputs and weight-quantized - outputs + Calculate mean absolute error between weight-uncompressed outputs and + weight-compressed outputs - :param unquantized: unquantized-weight outputs - :param quantized: quantized-weight outputs - :return: mean l1 loss between outputs + :param uncompressed: uncompressed-weight outputs + :param compressed: compressed-weight outputs + :return: mean absolute error between outputs """ - unquantized_outputs = sum( - [ - [output for output in outputs] - if isinstance(outputs, Iterable) - else [outputs] - for outputs, _ in unquantized - ], - start=[], - ) - - quantized_outputs = sum( - [ - [output for output in outputs] - if isinstance(outputs, Iterable) - else [outputs] - for outputs, _ in quantized - ], - start=[], - ) - - if len(unquantized_outputs) != len(quantized_outputs): - raise ValueError( - "Number of samples of weight-unquantized and weight-quantized " - "outputs differs" - ) - - return sum( - [ - torch.nn.functional.l1_loss(unq, q) - for unq, q in zip(unquantized_outputs, quantized_outputs) - ] - ) / len(unquantized_outputs) - + # assume first output is the the relevant output (true for most Modules) + uncompressed = uncompressed[0] + compressed = compressed[0] + return torch.mean(torch.abs(uncompressed - compressed)) + def gptq_hook(func): def wrapped(self, *args, **kwargs): if self._hooks_disabled: diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index b6fe73ce6..a2bdf0582 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -28,30 +28,35 @@ class HooksMixin: - def __init__(self): - self.hooks = [] - self.hooks_disabled = False + HOOKS_DISABLED: bool = False @classmethod - def hook(func): - def wrapped(self, *args, **kwargs): - if self.hooks_disabled: + def hook(cls, func): + def wrapped(*args, **kwargs): + if cls.HOOKS_DISABLED: return - func(self, *args, **kwargs) + func(*args, **kwargs) return wrapped + @classmethod @contextlib.contextmanager - def disable_hooks(self): + def disable_hooks(cls): try: - self._hooks_disabled = True + cls.HOOKS_DISABLED = True yield finally: - self._hooks_disabled = False + cls.HOOKS_DISABLED = False + + def __init__(self): + self._hooks = [] + + def register_hook(self, handle: torch.utils.hooks.RemovableHandle): + self._hooks.append(handle) def remove_hooks(self): - for hook in self.hooks: + for hook in self._hooks: hook.remove() @@ -61,6 +66,7 @@ def __init__( compress_fn: Callable[[str, torch.nn.Module, torch.Tensor], Any], true_sequential: bool = True, ): + HooksMixin.__init__(self) self.compress_fn = compress_fn self.true_sequential = true_sequential @@ -74,8 +80,8 @@ def register_hooks( # if no targets are provided, default to the modules that shouldn't be # split by FSDP. For Transformers models this is equivalent to the # decoder layers (ie LlamaDecoderLayer) - if self.sequential_targets is None: - self.sequential_targets = get_no_split_params(model) + if sequential_targets is None: + sequential_targets = get_no_split_params(model) layers = get_layers(sequential_targets, model) self._num_layers = len(layers) @@ -83,16 +89,14 @@ def register_hooks( if getattr_chain(module, "quantization_scheme.weights", None) is not None: pre_hook = partial(self.target_pre_forward, name) post_hook = partial(self.target_post_forward, name) - self._hooks.append(module.register_forward_pre_hook(pre_hook)) - self._hooks.append(module.register_forward_hook(post_hook)) + self.register_hook(module.register_forward_pre_hook(pre_hook)) + self.register_hook(module.register_forward_hook(post_hook)) - if module in layers.values(): + if name in layers.keys(): pre_hook = partial(self.layer_pre_forward, name) post_hook = partial(self.layer_post_forward, name) - self._hooks.append(module.register_forward_pre_hook(pre_hook)) - self._hooks.append( - module.register_forward_hook(post_hook, with_kwargs=True) - ) + self.register_hook(module.register_forward_pre_hook(pre_hook)) + self.register_hook(module.register_forward_hook(post_hook, with_kwargs=True)) @HooksMixin.hook def target_pre_forward(self, name: str, module: torch.nn.Module, args): @@ -121,11 +125,11 @@ def layer_post_forward( module: torch.nn.Module, args: torch.Tensor, kwargs: Dict[str, Any], - output: Any, + output: Tuple[torch.Tensor, ...], ): if not self.true_sequential: # rerun with (now) compressed weights - with self.disable_hooks(): + with HooksMixin.disable_hooks(): compressed_output = module(*args, **kwargs) error = get_output_error(output, compressed_output) diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index 03abf18be..c414d134b 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -1125,3 +1125,6 @@ def align_module(module: torch.nn.Module, device: Optional[torch.device] = None) for name, param_device in module.named_parameters: setattr(module, name, param.to(param_device)) + + else: + yield diff --git a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index 5421af4cf..3cdc25038 100644 --- a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -75,15 +75,15 @@ def test_create_default_quant_modifier(self): kwargs = dict(block_size=128) modifier = GPTQModifier(**kwargs) - assert modifier.quantization_modifier_ is None + assert modifier._quantization_modifier is None testing_harness = LifecyleTestingHarness(model=LinearNet()) modifier.on_initialize_structure(testing_harness.get_state()) assert modifier.quantize - assert isinstance(modifier.quantization_modifier_, QuantizationModifier) - modifier.quantization_modifier_.create_init_config() + assert isinstance(modifier._quantization_modifier, QuantizationModifier) + modifier._quantization_modifier.create_init_config() default_config_group_name = "group_0" - should_be_default_quant_scheme = modifier.quantization_modifier_.config_groups[ + should_be_default_quant_scheme = modifier._quantization_modifier.config_groups[ default_config_group_name ] assert should_be_default_quant_scheme.input_activations is None @@ -113,7 +113,7 @@ def test_set_quant_if_modifer_already_exists(self): kwargs = dict(block_size=128) modifier = GPTQModifier(**kwargs) - assert not modifier.quantization_modifier_ + assert not modifier._quantization_modifier modifier.on_initialize_structure(testing_harness.get_state()) # since quantization modifier is already applied, quantization must be set in @@ -150,14 +150,14 @@ def test_set_quant_in_gptq(self): kwargs = dict(block_size=128, quantize=self.quant_config) modifier = GPTQModifier(**kwargs) - assert modifier.quantization_modifier_ is None + assert modifier._quantization_modifier is None testing_harness = LifecyleTestingHarness(model=LinearNet()) modifier.on_initialize_structure(testing_harness.get_state()) assert modifier.quantize - self.assertIsInstance(modifier.quantization_modifier_, QuantizationModifier) + self.assertIsInstance(modifier._quantization_modifier, QuantizationModifier) - dict_scheme = dict(modifier.quantization_modifier_.config_groups) + dict_scheme = dict(modifier._quantization_modifier.config_groups) self._check_config( dict(dict_scheme["config_group_0"].weights), self.quant_kwargs["config_groups"]["config_group_0"]["weights"], From 9324695131e30f3755a6b2d64be6c6a3d0731ac5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 21:16:40 +0000 Subject: [PATCH 027/285] replicate dtypes --- .../modifiers/quantization/gptq/utils/gptq_quantize.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index a94a8bf69..a193ae817 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -49,7 +49,12 @@ def invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: def compute_scale_zeropoint( W: torch.Tensor, quant_args: QuantizationArgs ) -> Tuple[torch.Tensor, torch.Tensor]: - return MovingAverageMinMaxObserver(quant_args)(W) + # TODO: revisit after observers refactor + + scale, zero_point = quant_args.get_observer()(W, g_idx=None) + scale = scale.to(dtype=W.dtype) + zero_point = zero_point.to(dtype=quant_args.pytorch_dtype()) + return scale, zero_point def quantize_weight( From eef4fb6f666b688f1f3b936a6efe8c170a644908 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 21:21:30 +0000 Subject: [PATCH 028/285] write weight changes --- src/llmcompressor/modifiers/quantization/gptq/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 44dcdf194..7e0617bdb 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -102,7 +102,7 @@ class GPTQModifier(Modifier): """ sequential_update: bool = True - true_sequential: bool = True + true_sequential: bool = False targets: Union[str, List[str], None] = None sequential_targets: Union[str, List[str], None] = None block_size: int = 128 @@ -284,6 +284,7 @@ def quantize_module(self, name, module, args): if is_module_offloaded(module): update_prefix_dict(self.layer, "weight", weight) + update_parameter_data(module, weight, "weight") update_parameter_data(module, scale, "weight_scale") update_parameter_data(module, zero_point, "weight_zero_point") update_parameter_data(module, g_idx, "weight_g_idx") From 485813a6d73a4b0749d2fc44f839243423601269 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 22:04:39 +0000 Subject: [PATCH 029/285] revert example --- examples/quantization_w4a16/llama3_example.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index 96adcbfdc..939991ab6 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -1,4 +1,3 @@ -import torch from datasets import load_dataset from transformers import AutoTokenizer @@ -6,9 +5,7 @@ from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot # Select model and load it. -# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" -# MODEL_ID = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" -MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" +MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" model = SparseAutoModelForCausalLM.from_pretrained( MODEL_ID, @@ -23,8 +20,8 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 512 // 6 -MAX_SEQUENCE_LENGTH = 2048 // 2 +NUM_CALIBRATION_SAMPLES = 512 +MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) @@ -44,13 +41,10 @@ def preprocess(example): # Tokenize inputs. -tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - - def tokenize(sample): return tokenizer( sample["text"], - padding=True, + padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False, @@ -61,9 +55,7 @@ def tokenize(sample): # Configure the quantization algorithm to run. # * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = GPTQModifier( - targets="Linear", scheme="W4A16", ignore=["lm_head"], percdamp=0.01 -) +recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]) # Apply algorithms. oneshot( From 60061551514efc3f15f1fbd6e38e033703252bb5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 22:33:16 +0000 Subject: [PATCH 030/285] organization --- .../modifiers/quantization/gptq/base.py | 25 +++--- .../quantization/gptq/utils/gptq_quantize.py | 8 +- .../quantization/gptq/utils/helpers.py | 84 ------------------- .../modifiers/utils/layer_compressor.py | 24 ++++-- src/llmcompressor/utils/metric_logging.py | 54 +++++++++++- 5 files changed, 82 insertions(+), 113 deletions(-) delete mode 100644 src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 7e0617bdb..2917c85c4 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from compressed_tensors.quantization import ( @@ -19,7 +19,6 @@ from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import ( quantize_weight, ) -from llmcompressor.modifiers.quantization.gptq.utils.helpers import MetricsLogger from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier from llmcompressor.modifiers.utils.layer_compressor import SequentialLayerCompressor from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward @@ -261,16 +260,17 @@ def collate_fn(batch): with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) - def quantize_module(self, name, module, args): - logger.info(f"Compressing {name}...") + def quantize_module( + self, name: str, module: torch.nn.Module, args: Tuple[torch.Tensor, ...] + ) -> float: + logger.info(f"Quantizing {name}...") - # Assume that first argument is input (true for most supported Module types) + # Assume that first argument is the input inp = args[0] quant_args = getattr_chain(module, "quantization_scheme.weights") - # with onloaded weight - with align_module(module), MetricsLogger(module) as metrics_logger: - losses, quantized_weight, scale, zero_point, g_idx = quantize_weight( + with align_module(module): + loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( module.weight.data, inp, quant_args, @@ -279,17 +279,16 @@ def quantize_module(self, name, module, args): module_class=type(module), ) - # weight = torch.lerp(module.weight.data, quantized_weight, self.alpha) - weight = quantized_weight + # FUTURE: Implement learning rate modification to weight update if is_module_offloaded(module): - update_prefix_dict(self.layer, "weight", weight) - update_parameter_data(module, weight, "weight") + update_prefix_dict(self.layer, "weight", quantized_weight) + update_parameter_data(module, quantized_weight, "weight") update_parameter_data(module, scale, "weight_scale") update_parameter_data(module, zero_point, "weight_zero_point") update_parameter_data(module, g_idx, "weight_g_idx") - metrics_logger.set_losses(losses) + return loss def _build_quant_modifier(self): """ diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index a193ae817..b21956cee 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -10,7 +10,6 @@ QuantizationStrategy, fake_quantize, ) -from compressed_tensors.quantization.observers import MovingAverageMinMaxObserver from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD from llmcompressor.pytorch.utils.helpers import tensor_sparsity @@ -64,9 +63,7 @@ def quantize_weight( blocksize: int = 128, percdamp: float = 0.01, module_class=torch.nn.Linear, -) -> Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor -]: +) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]: strategy = quant_args.strategy actorder = quant_args.actorder final_shape = weight.shape @@ -230,7 +227,8 @@ def quantize_weight( W.transpose_(0, 1) W = W.reshape(final_shape).to(final_dtype) - return losses, W, scale, zero_point, g_idx + loss = torch.sum(losses).item() + return loss, W, scale, zero_point, g_idx def _apply_activation_ordering( diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py deleted file mode 100644 index a369e0d4c..000000000 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py +++ /dev/null @@ -1,84 +0,0 @@ -import time -from typing import Any, Iterable, List, Tuple, Union - -import torch -from loguru import logger - -from llmcompressor.utils.metric_logging import ( - get_GPU_memory_usage, - get_layer_size_bytes, -) - -__all__ = ["get_output_error", "gptq_hook", "MetricsLogger"] - - -def get_output_error( - uncompressed: Tuple[torch.Tensor, ...], - compressed: Tuple[torch.Tensor, ...], -) -> torch.Tensor: - """ - Calculate mean absolute error between weight-uncompressed outputs and - weight-compressed outputs - - :param uncompressed: uncompressed-weight outputs - :param compressed: compressed-weight outputs - :return: mean absolute error between outputs - """ - # assume first output is the the relevant output (true for most Modules) - uncompressed = uncompressed[0] - compressed = compressed[0] - - return torch.mean(torch.abs(uncompressed - compressed)) - -def gptq_hook(func): - def wrapped(self, *args, **kwargs): - if self._hooks_disabled: - return - - func(self, *args, **kwargs) - - return wrapped - - -class MetricsLogger: - def __init__(self, module: torch.nn.Module): - self.module = module - self.start_tick = None - self.losses = None - - def set_losses(self, losses: torch.Tensor): - self.losses = losses - - def __enter__(self) -> "MetricsLogger": - self.start_tick = time.time() - return self - - def __exit__(self, _exc_type, _exc_val, _exc_tb): - """ - Log metrics related to compression algorithm - - :param start_tick: time when algorithm started" - :param losses: loss as result of algorithm - """ - patch = logger.patch(lambda r: r.update(function="compress")) - - if self.start_tick is not None: - patch.log("METRIC", "time %.2f" % (time.time() - self.start_tick)) - if self.losses is not None: - patch.log("METRIC", "error %.2f" % torch.sum(self.losses).item()) - - gpu_usage = get_GPU_memory_usage() - if len(gpu_usage) > 0: - for i in range(len(gpu_usage)): - perc = gpu_usage[i][0] * 100 - total_memory = int(gpu_usage[i][1]) # GB - patch.log( - "METRIC", - ( - f"GPU {i} | usage: {perc:.2f}%" - f" | total memory: {total_memory} GB" - ), - ) - - compressed_size = get_layer_size_bytes(self.module) - patch.log("METRIC", f"Compressed layer size: {compressed_size} MB") diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index a2bdf0582..5bb459372 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -8,7 +8,6 @@ from loguru import logger from tqdm import tqdm -from llmcompressor.modifiers.quantization.gptq.utils.helpers import get_output_error from llmcompressor.modifiers.utils.compression_wrapper import ModuleCompressionWrapper from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException from llmcompressor.pytorch.utils import tensors_to_device @@ -17,6 +16,7 @@ summon_full_params_context, ) from llmcompressor.utils.helpers import getattr_chain +from llmcompressor.utils.metric_logging import CompressionLogger from llmcompressor.utils.pytorch.module import ( get_layers, get_no_split_params, @@ -24,7 +24,7 @@ set_layer, ) -__all__ = ["LayerCompressor"] +__all__ = ["SequentialLayerCompressor", "LayerCompressor"] class HooksMixin: @@ -63,7 +63,7 @@ def remove_hooks(self): class SequentialLayerCompressor(HooksMixin): def __init__( self, - compress_fn: Callable[[str, torch.nn.Module, torch.Tensor], Any], + compress_fn: Callable[[str, torch.nn.Module, torch.Tensor], float], true_sequential: bool = True, ): HooksMixin.__init__(self) @@ -96,21 +96,27 @@ def register_hooks( pre_hook = partial(self.layer_pre_forward, name) post_hook = partial(self.layer_post_forward, name) self.register_hook(module.register_forward_pre_hook(pre_hook)) - self.register_hook(module.register_forward_hook(post_hook, with_kwargs=True)) + self.register_hook( + module.register_forward_hook(post_hook, with_kwargs=True) + ) @HooksMixin.hook def target_pre_forward(self, name: str, module: torch.nn.Module, args): if self.true_sequential: - # compress first so output is from quantized weights - self.compress_fn(name, module, args) + # compress first so output is from compressed weights + with CompressionLogger(module) as comp_logger: + loss = self.compress_fn(name, module, args) + comp_logger.set_loss(loss) @HooksMixin.hook def target_post_forward( self, name: str, module: torch.nn.Module, args: torch.Tensor, _output: Any ): if not self.true_sequential: - # compress after so output is from unquantized weights - self.compress_fn(name, module, args) + # compress after so output is from uncompressed weights + with CompressionLogger(module) as comp_logger: + loss = self.compress_fn(name, module, args) + comp_logger.set_loss(loss) @HooksMixin.hook def layer_pre_forward(self, name: str, module: torch.nn.Module, args: Any): @@ -132,7 +138,7 @@ def layer_post_forward( with HooksMixin.disable_hooks(): compressed_output = module(*args, **kwargs) - error = get_output_error(output, compressed_output) + error = torch.nn.functional.l1_loss(output[0], compressed_output[0]) logger.info(f"Mean output error from quantization: {error:.3f}") self._layer_index += 1 diff --git a/src/llmcompressor/utils/metric_logging.py b/src/llmcompressor/utils/metric_logging.py index d0b3bb11e..b23ba200a 100644 --- a/src/llmcompressor/utils/metric_logging.py +++ b/src/llmcompressor/utils/metric_logging.py @@ -1,7 +1,10 @@ +import time from typing import List, Tuple +import torch from loguru import logger -from torch.nn import Module + +__all__ = ["CompressionLogger"] def get_GPU_memory_usage() -> List[Tuple]: @@ -35,7 +38,7 @@ def get_GPU_memory_usage() -> List[Tuple]: return [] -def get_layer_size_bytes(module: Module) -> float: +def get_module_size_bytes(module: torch.nn.Module) -> float: param_size = 0 buffer_size = 0 @@ -49,3 +52,50 @@ def get_layer_size_bytes(module: Module) -> float: total_size_mb = total_size / (1024**2) # Convert bytes to MB return total_size_mb + + +class CompressionLogger: + """ + Log metrics related to compression algorithm + + :param start_tick: time when algorithm started" + :param losses: loss as result of algorithm + """ + + def __init__(self, module: torch.nn.Module): + self.module = module + self.start_tick = None + self.loss = None + + def set_loss(self, loss: float): + self.loss = loss + + def __enter__(self) -> "CompressionLogger": + self.start_tick = time.time() + return self + + def __exit__(self, _exc_type, _exc_val, _exc_tb): + stop_tick = time.time() + patch = logger.patch(lambda r: r.update(function="compress")) + + if self.start_tick is not None: + duration = stop_tick - self.start_tick + patch.log("METRIC", f"time {duration:.2f}") + if self.loss is not None: + patch.log("METRIC", f"error {self.loss:.2f}") + + gpu_usage = get_GPU_memory_usage() + if len(gpu_usage) > 0: + for i in range(len(gpu_usage)): + perc = gpu_usage[i][0] * 100 + total_memory = int(gpu_usage[i][1]) # GB + patch.log( + "METRIC", + ( + f"GPU {i} | usage: {perc:.2f}%" + f" | total memory: {total_memory} GB" + ), + ) + + compressed_size = get_module_size_bytes(self.module) + patch.log("METRIC", f"Compressed module size: {compressed_size} MB") From c10d2ee3d9f82d68926fe1592f67a9dd2b798bbc Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 22:50:40 +0000 Subject: [PATCH 031/285] add create_single_batch_dataloader --- .../modifiers/quantization/gptq/base.py | 33 +++-------------- .../quantization/gptq/utils/__init__.py | 1 - .../finetune/data/data_helpers.py | 35 +++++++++++++++++-- 3 files changed, 37 insertions(+), 32 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 2917c85c4..1b03130c7 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -12,7 +12,6 @@ ) from loguru import logger from pydantic import Field, field_validator -from torch.nn.utils.rnn import pad_sequence from llmcompressor.core import State from llmcompressor.modifiers import Modifier, ModifierFactory @@ -22,6 +21,9 @@ from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier from llmcompressor.modifiers.utils.layer_compressor import SequentialLayerCompressor from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward +from llmcompressor.transformers.finetune.data.data_helpers import ( + create_single_batch_dataloader, +) from llmcompressor.utils.helpers import ( align_module, calibration_forward_context, @@ -229,34 +231,7 @@ def on_finalize(self, state: "State", **kwargs) -> bool: def calibration_forward( self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader ): - dataset = dataloader.dataset - - def collate_fn(batch): - # extract input_ids and attention_mask from the batch - input_ids = [torch.tensor(item["input_ids"]) for item in batch] - attention_masks = [torch.tensor(item["attention_mask"]) for item in batch] - - # pad sequences in the batch - padded_input_ids = pad_sequence( - input_ids, batch_first=True, padding_value=0 - ) - padded_attention_masks = pad_sequence( - attention_masks, batch_first=True, padding_value=0 - ) - - return { - "input_ids": padded_input_ids, - "attention_mask": padded_attention_masks, - } - - dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=len(dataset), - shuffle=True, - collate_fn=collate_fn, - pin_memory=True, - ) - + dataloader = create_single_batch_dataloader(dataloader.dataset) with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py b/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py index 5703ced46..ec39da973 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py @@ -1,4 +1,3 @@ # flake8: noqa from .gptq_quantize import * -from .helpers import * diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index 23c70e561..933f64bd9 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -1,9 +1,11 @@ import logging import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, Optional +import datasets import torch from datasets import Dataset, load_dataset +from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from transformers.data import default_data_collator @@ -11,20 +13,49 @@ LABELS_MASK_VALUE = -100 __all__ = [ + "create_single_batch_dataloader", "format_calibration_data", "get_raw_dataset", "make_dataset_splits", "get_custom_datasets_from_path", + "LABELS_MASK_VALUE", ] +def create_single_batch_dataloader( + dataset: datasets.Dataset, +) -> torch.utils.data.DataLoader: + def pad_sequences(batch): + # extract input_ids and attention_mask from the batch + input_ids = [torch.tensor(item["input_ids"]) for item in batch] + masks = [torch.tensor(item["attention_mask"]) for item in batch] + + # while 0 is not necessarily the "correct" padding value, the padded + # input_ids are ignored according to the attention_mask + pad_input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + pad_masks = pad_sequence(masks, batch_first=True, padding_value=0) + + return { + "input_ids": pad_input_ids, + "attention_mask": pad_masks, + } + + return torch.utils.data.DataLoader( + dataset, + batch_size=len(dataset), + shuffle=True, + collate_fn=pad_sequences, + pin_memory=True, + ) + + def format_calibration_data( tokenized_dataset: Dataset, num_calibration_samples: Optional[int] = None, do_shuffle: bool = True, collate_fn: Callable = default_data_collator, accelerator: Optional[Any] = None, -) -> List[torch.Tensor]: +) -> torch.utils.data.DataLoader: """ Creates a dataloader out of the calibration dataset split, trimming it to the desired number of calibration samples From 637119322b4cf2202b34a7f7385d5102fe5ff013 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 22:54:13 +0000 Subject: [PATCH 032/285] add back empty_cache until I can justify removing it --- src/llmcompressor/modifiers/quantization/gptq/base.py | 2 +- src/llmcompressor/modifiers/utils/pytorch_helpers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 1b03130c7..b59085dab 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -206,7 +206,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: # after lifecycle refactor, move this to pre_batch self._layer_compressor.register_hooks(state.model, self.sequential_targets) - # apply calibration and trigger hooks (hooks are self removing) + # apply calibration and trigger hooks self.calibration_forward(state.model, state.data.calib) # freeze quantization diff --git a/src/llmcompressor/modifiers/utils/pytorch_helpers.py b/src/llmcompressor/modifiers/utils/pytorch_helpers.py index c2f52a1cf..9003ff22d 100644 --- a/src/llmcompressor/modifiers/utils/pytorch_helpers.py +++ b/src/llmcompressor/modifiers/utils/pytorch_helpers.py @@ -102,7 +102,7 @@ def run_calibration_forward( # TODO: not ideal, figure out where we aren't freeing memory instead # currently without this we run OOM on the 2nd forward pass - # torch.cuda.empty_cache() + torch.cuda.empty_cache() return intermediates From 92315a5197e9321e6d9bd3667b9505e9a527027c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 23:26:56 +0000 Subject: [PATCH 033/285] better type hinting, faster mask applying --- src/llmcompressor/modifiers/utils/layer_compressor.py | 11 ++++++----- src/llmcompressor/modifiers/utils/pytorch_helpers.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index 5bb459372..833c8176d 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -101,7 +101,8 @@ def register_hooks( ) @HooksMixin.hook - def target_pre_forward(self, name: str, module: torch.nn.Module, args): + def target_pre_forward(self, name: str, module: torch.nn.Module, args: Tuple[Any, ...]): + breakpoint() if self.true_sequential: # compress first so output is from compressed weights with CompressionLogger(module) as comp_logger: @@ -110,7 +111,7 @@ def target_pre_forward(self, name: str, module: torch.nn.Module, args): @HooksMixin.hook def target_post_forward( - self, name: str, module: torch.nn.Module, args: torch.Tensor, _output: Any + self, name: str, module: torch.nn.Module, args: Tuple[Any, ...], _output: Tuple[Any, ...] ): if not self.true_sequential: # compress after so output is from uncompressed weights @@ -119,7 +120,7 @@ def target_post_forward( comp_logger.set_loss(loss) @HooksMixin.hook - def layer_pre_forward(self, name: str, module: torch.nn.Module, args: Any): + def layer_pre_forward(self, _name: str, _module: torch.nn.Module, _args: Any): logger.info( f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====" ) @@ -129,9 +130,9 @@ def layer_post_forward( self, name: str, module: torch.nn.Module, - args: torch.Tensor, + args: Tuple[Any, ...], kwargs: Dict[str, Any], - output: Tuple[torch.Tensor, ...], + output: Tuple[Any, ...], ): if not self.true_sequential: # rerun with (now) compressed weights diff --git a/src/llmcompressor/modifiers/utils/pytorch_helpers.py b/src/llmcompressor/modifiers/utils/pytorch_helpers.py index 9003ff22d..43b261d99 100644 --- a/src/llmcompressor/modifiers/utils/pytorch_helpers.py +++ b/src/llmcompressor/modifiers/utils/pytorch_helpers.py @@ -39,7 +39,7 @@ def apply_pad_mask_to_batch(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.T :param batch: batch to apply padding to if it exists :return: batch with padding zeroed out in the input_ids """ - batch["input_ids"] = batch["input_ids"] * batch["attention_mask"] + batch["input_ids"].masked_fill_(batch["attention_mask"] == 0, 0) return batch From 8a25c68438a487b423c8b872b3b42d7a30b36b4d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 23:37:59 +0000 Subject: [PATCH 034/285] remove breakpoint --- src/llmcompressor/modifiers/utils/layer_compressor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index 833c8176d..8fd933a1d 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -102,7 +102,6 @@ def register_hooks( @HooksMixin.hook def target_pre_forward(self, name: str, module: torch.nn.Module, args: Tuple[Any, ...]): - breakpoint() if self.true_sequential: # compress first so output is from compressed weights with CompressionLogger(module) as comp_logger: From 6cd0d6cc1255fea96fd99599cc09b5e72c44bdb0 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 23:41:10 +0000 Subject: [PATCH 035/285] apply style, add true_sequential docstring --- src/llmcompressor/modifiers/quantization/gptq/base.py | 7 +++++-- src/llmcompressor/modifiers/utils/layer_compressor.py | 10 ++++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 04260340c..0958602da 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -1,3 +1,4 @@ +import warnings from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -11,7 +12,6 @@ update_prefix_dict, ) from loguru import logger -import warnings from pydantic import Field, field_validator from llmcompressor.core import State @@ -72,7 +72,10 @@ class GPTQModifier(Modifier): :param sequential_update: Whether or not to update weights sequentially by layer. This option is depreciated and setting to False is no longer supported - :param true_sequential: TODO + :param true_sequential: Used to control the granularity of compression updates + through the forward pass. Set to True to use the weight-compressed outputs + of each module, set to False to use the weight-compressed outputs of each + layer (transformer block) :param targets: list of layer names to compress during GPTQ, or '__ALL__' to compress every layer in the model :param block_size: Used to determine number of columns to compress in one pass diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index 8fd933a1d..c2130068f 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -101,7 +101,9 @@ def register_hooks( ) @HooksMixin.hook - def target_pre_forward(self, name: str, module: torch.nn.Module, args: Tuple[Any, ...]): + def target_pre_forward( + self, name: str, module: torch.nn.Module, args: Tuple[Any, ...] + ): if self.true_sequential: # compress first so output is from compressed weights with CompressionLogger(module) as comp_logger: @@ -110,7 +112,11 @@ def target_pre_forward(self, name: str, module: torch.nn.Module, args: Tuple[Any @HooksMixin.hook def target_post_forward( - self, name: str, module: torch.nn.Module, args: Tuple[Any, ...], _output: Tuple[Any, ...] + self, + name: str, + module: torch.nn.Module, + args: Tuple[Any, ...], + _output: Tuple[Any, ...], ): if not self.true_sequential: # compress after so output is from uncompressed weights From 0e0c586c4773d5b9b2176d4d0688468361bbc94a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 23:42:25 +0000 Subject: [PATCH 036/285] update docstring --- src/llmcompressor/modifiers/quantization/gptq/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 0958602da..22404f0da 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -1,4 +1,3 @@ -import warnings from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -12,6 +11,7 @@ update_prefix_dict, ) from loguru import logger +import warnings from pydantic import Field, field_validator from llmcompressor.core import State @@ -53,6 +53,7 @@ class GPTQModifier(Modifier): | test_stage: | obcq_modifiers: | GPTQModifier: + | true_sequential: False | dampening_frac: 0.001 | block_size: 128 | config_groups: @@ -75,7 +76,7 @@ class GPTQModifier(Modifier): :param true_sequential: Used to control the granularity of compression updates through the forward pass. Set to True to use the weight-compressed outputs of each module, set to False to use the weight-compressed outputs of each - layer (transformer block) + layer (transformer block), defaults to False :param targets: list of layer names to compress during GPTQ, or '__ALL__' to compress every layer in the model :param block_size: Used to determine number of columns to compress in one pass From d23aabb1330aa408b3374eb79341e08a9b0e3f7b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 22 Oct 2024 23:46:58 +0000 Subject: [PATCH 037/285] use private attrs --- src/llmcompressor/modifiers/quantization/gptq/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 22404f0da..05a7ab0a6 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -12,7 +12,7 @@ ) from loguru import logger import warnings -from pydantic import Field, field_validator +from pydantic import Field, PrivateAttr, field_validator from llmcompressor.core import State from llmcompressor.modifiers import Modifier, ModifierFactory @@ -119,8 +119,8 @@ class GPTQModifier(Modifier): num_calibration_steps: Optional[int] = None scheme: Optional[Union[str, Dict[str, Any]]] = None - _quantization_modifier: Optional[QuantizationModifier] = None - _layer_compressor: Optional[SequentialLayerCompressor] = None + _quantization_modifier: Optional[QuantizationModifier] = PrivateAttr() + _layer_compressor: Optional[SequentialLayerCompressor] = PrivateAttr() @field_validator("sequential_update", mode="before") def validate_sequential_update(cls, value: bool) -> bool: From 355074b2bf815ac9a06f30617c50b15d9ccd4364 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 23 Oct 2024 00:08:23 +0000 Subject: [PATCH 038/285] more docstring --- .../modifiers/quantization/gptq/base.py | 2 +- .../modifiers/utils/layer_compressor.py | 27 ++++++++++++++++--- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 05a7ab0a6..0f4f7de95 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -1,3 +1,4 @@ +import warnings from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -11,7 +12,6 @@ update_prefix_dict, ) from loguru import logger -import warnings from pydantic import Field, PrivateAttr, field_validator from llmcompressor.core import State diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index c2130068f..e3e5c6217 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -1,7 +1,7 @@ import contextlib import operator from functools import partial -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from compressed_tensors import get_execution_device @@ -61,10 +61,29 @@ def remove_hooks(self): class SequentialLayerCompressor(HooksMixin): + """ + Apply a given compression function to a model during the model's calibration + forward pass + + Lifecycle: + - QuantizationModifier.initialize(model) + - SequentialLayerCompressor(compress_fn) + - register_hooks(model) + - model.forward() + - compress_fn(name, target_module, args) + - remove_hooks() + + :param compress_fn: Function to be called on target modules + :param true_sequential: Used to control the granularity of compression updates + through the forward pass. Set to True to use the weight-compressed outputs + of each module, set to False to use the weight-compressed outputs of each + layer (transformer block), defaults to False + """ + def __init__( self, compress_fn: Callable[[str, torch.nn.Module, torch.Tensor], float], - true_sequential: bool = True, + true_sequential: bool = False, ): HooksMixin.__init__(self) self.compress_fn = compress_fn @@ -74,7 +93,9 @@ def __init__( self._num_layers = 0 def register_hooks( - self, model: torch.nn.Module, sequential_targets: Union[str, List[str], None] + self, + model: torch.nn.Module, + sequential_targets: Optional[Union[str, List[str]]] = None, ): # find layers (used for printing even if true_sequential=True) # if no targets are provided, default to the modules that shouldn't be From bf2184d60bcc8671b3d6d6e587ba0dcfe75504da Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 23 Oct 2024 00:18:43 +0000 Subject: [PATCH 039/285] docstrings --- .../modifiers/quantization/gptq/base.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 0f4f7de95..2c31d0a70 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -235,6 +235,13 @@ def on_finalize(self, state: "State", **kwargs) -> bool: def calibration_forward( self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader ): + """ + Perform calibration forward pass with one batch whose size is the size + of the dataset + + :param model: model to perform forward pass with + :param dataloader: dataloader containing calibration dataset + """ dataloader = create_single_batch_dataloader(dataloader.dataset) with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) @@ -242,6 +249,15 @@ def calibration_forward( def quantize_module( self, name: str, module: torch.nn.Module, args: Tuple[torch.Tensor, ...] ) -> float: + """ + Quantize a module's weight according to the GPTQ algorithm + + :param name: name of module being quantized + :param module: module being quantized + :param args: input arguments for module forward pass + + :return: total loss from applying weight quantization to this module + """ logger.info(f"Quantizing {name}...") # Assume that first argument is the input From 0b418c7efe2a6810179d942030dd40a0e7c38148 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 23 Oct 2024 00:27:14 +0000 Subject: [PATCH 040/285] docstrings --- .../quantization/gptq/utils/gptq_quantize.py | 46 +++++++++++++++++-- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index b21956cee..4301d944a 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -18,6 +18,13 @@ def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: + """ + Calculate the hessian with respect to the module inputs + + :param inp: module inputs + :param module_class: class of module, likely torch.nn.Linear + :return: hessian w.r.t. module inputs + """ inp = inp.to(device=device) if len(inp.shape) == 2: inp = inp.unsqueeze(0) @@ -36,6 +43,13 @@ def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: def invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: + """ + Performs in-place inversion of the hessian in order to save memory + + :param H: hessian being inverted + :param percdamp: dampening factor on hessian diagonal + :return: inverted hessian + """ damp = percdamp * torch.mean(torch.diag(H)) diag = torch.arange(H.shape[0], device=H.device) H[diag, diag] += damp @@ -45,9 +59,18 @@ def invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: return H -def compute_scale_zeropoint( +def compute_scale_zero_point( W: torch.Tensor, quant_args: QuantizationArgs ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute the scale and zero point of a module weight + TODO: revisit after observers refactor + + :param W: module weight + :param quant_args: quantization arguments which determine how quantization + parameters are calculated + :return: scale and zero_point + """ # TODO: revisit after observers refactor scale, zero_point = quant_args.get_observer()(W, g_idx=None) @@ -64,6 +87,17 @@ def quantize_weight( percdamp: float = 0.01, module_class=torch.nn.Linear, ) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]: + """ + Quantize a module weight according to the GPTQ algorithm + + :param weight: weight being quantized + :param inp: module inputs used to calculate hessian + :param quant_args: quantization arguments used to find quantization parameters + :param blocksize: chunk size of quantization updates + :param percdamp: dampening factor on hessian diagonal + :param module_class: class of module, likely torch.nn.Linear + :return: loss, quantized_weight, scale, zero_point, g_idx + """ strategy = quant_args.strategy actorder = quant_args.actorder final_shape = weight.shape @@ -91,22 +125,22 @@ def quantize_weight( if actorder == ActivationOrdering.GROUP: # permute by activation order first, then update groups W, H, perm = _apply_activation_ordering(W, H) - scale, zero_point = compute_scale_zeropoint(W, quant_args) + scale, zero_point = compute_scale_zero_point(W, quant_args) # use identity g_idx (invert permutation later) elif actorder == ActivationOrdering.WEIGHT: # update groups first, then permute by activation order - scale, zero_point = compute_scale_zeropoint(W, quant_args) + scale, zero_point = compute_scale_zero_point(W, quant_args) W, H, perm = _apply_activation_ordering(W, H) # permute g_idx to maintain identity mapping after unpermutation g_idx = g_idx[perm] else: - scale, zero_point = compute_scale_zeropoint(W, quant_args) + scale, zero_point = compute_scale_zero_point(W, quant_args) else: - scale, zero_point = compute_scale_zeropoint(W, quant_args) + scale, zero_point = compute_scale_zero_point(W, quant_args) # sparsity mask sparsity = tensor_sparsity(W) @@ -238,6 +272,8 @@ def _apply_activation_ordering( Permute weight and hessian in order of greatest outupt activations :param W: weight to permute + :param H: hessian used to determine activation ordering + :return: permuted weight, permuted hessian, permutation map """ perm = torch.argsort(torch.diag(H), descending=True) return W[:, perm], H[perm][:, perm], perm From 56cceeaccb6cd0ac54a39b53c678751ad807ccd5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 23 Oct 2024 00:31:58 +0000 Subject: [PATCH 041/285] docstrings --- .../transformers/finetune/data/data_helpers.py | 6 ++++++ src/llmcompressor/utils/helpers.py | 10 ++++++++++ 2 files changed, 16 insertions(+) diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index 933f64bd9..8a6f097a3 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -25,6 +25,12 @@ def create_single_batch_dataloader( dataset: datasets.Dataset, ) -> torch.utils.data.DataLoader: + """ + Create a dataloader whose batch size is equal to the size of the dataset + + :param dataset: dataset used to generate dataloader + :return: dataloader + """ def pad_sequences(batch): # extract input_ids and attention_mask from the batch input_ids = [torch.tensor(item["input_ids"]) for item in batch] diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index c414d134b..211bd01eb 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -1077,6 +1077,9 @@ def __exit__(self, _exc_type, _exc_val, _exc_tb): @contextlib.contextmanager def DisableQuantization(model: torch.nn.Module): + """ + Disable quantization from QuantizationModifier + """ model.apply(disable_quantization) yield model.apply(enable_quantization) @@ -1084,6 +1087,13 @@ def DisableQuantization(model: torch.nn.Module): @contextlib.contextmanager def calibration_forward_context(model: torch.nn.Module): + """ + Context in which all calibration forward passes should occur. + + - Remove gradient calculations + - Disable the KV cache + - Disable quantization from QuantizationModifier + """ model.eval() with ( From 7c7e3bc964921384472bb7f24d48e0759cc3e610 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 23 Oct 2024 00:33:44 +0000 Subject: [PATCH 042/285] move hooksmixin to separate file --- .../quantization/gptq/utils/gptq_quantize.py | 2 +- src/llmcompressor/modifiers/utils/hooks.py | 38 +++++++++++++++++++ .../modifiers/utils/layer_compressor.py | 35 +---------------- .../finetune/data/data_helpers.py | 3 +- 4 files changed, 42 insertions(+), 36 deletions(-) create mode 100644 src/llmcompressor/modifiers/utils/hooks.py diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 4301d944a..022252b0a 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -20,7 +20,7 @@ def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: """ Calculate the hessian with respect to the module inputs - + :param inp: module inputs :param module_class: class of module, likely torch.nn.Linear :return: hessian w.r.t. module inputs diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py new file mode 100644 index 000000000..d7e35015f --- /dev/null +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -0,0 +1,38 @@ +import contextlib + +import torch + +__all__ = ["HooksMixin"] + + +class HooksMixin: + HOOKS_DISABLED: bool = False + + @classmethod + def hook(cls, func): + def wrapped(*args, **kwargs): + if cls.HOOKS_DISABLED: + return + + func(*args, **kwargs) + + return wrapped + + @classmethod + @contextlib.contextmanager + def disable_hooks(cls): + try: + cls.HOOKS_DISABLED = True + yield + finally: + cls.HOOKS_DISABLED = False + + def __init__(self): + self._hooks = [] + + def register_hook(self, handle: torch.utils.hooks.RemovableHandle): + self._hooks.append(handle) + + def remove_hooks(self): + for hook in self._hooks: + hook.remove() diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index e3e5c6217..b168e2534 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -1,4 +1,3 @@ -import contextlib import operator from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -9,6 +8,7 @@ from tqdm import tqdm from llmcompressor.modifiers.utils.compression_wrapper import ModuleCompressionWrapper +from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException from llmcompressor.pytorch.utils import tensors_to_device from llmcompressor.utils.fsdp.context import ( @@ -27,39 +27,6 @@ __all__ = ["SequentialLayerCompressor", "LayerCompressor"] -class HooksMixin: - HOOKS_DISABLED: bool = False - - @classmethod - def hook(cls, func): - def wrapped(*args, **kwargs): - if cls.HOOKS_DISABLED: - return - - func(*args, **kwargs) - - return wrapped - - @classmethod - @contextlib.contextmanager - def disable_hooks(cls): - try: - cls.HOOKS_DISABLED = True - yield - finally: - cls.HOOKS_DISABLED = False - - def __init__(self): - self._hooks = [] - - def register_hook(self, handle: torch.utils.hooks.RemovableHandle): - self._hooks.append(handle) - - def remove_hooks(self): - for hook in self._hooks: - hook.remove() - - class SequentialLayerCompressor(HooksMixin): """ Apply a given compression function to a model during the model's calibration diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index 8a6f097a3..cc1c946ac 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -27,10 +27,11 @@ def create_single_batch_dataloader( ) -> torch.utils.data.DataLoader: """ Create a dataloader whose batch size is equal to the size of the dataset - + :param dataset: dataset used to generate dataloader :return: dataloader """ + def pad_sequences(batch): # extract input_ids and attention_mask from the batch input_ids = [torch.tensor(item["input_ids"]) for item in batch] From 2d52183760cebb80a9a87ce0c3e3a07796c799a7 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 23 Oct 2024 00:39:46 +0000 Subject: [PATCH 043/285] docstrings --- src/llmcompressor/modifiers/utils/hooks.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index d7e35015f..19c9a34ce 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -6,6 +6,15 @@ class HooksMixin: + """" + Class to manage the registration, disabling, and removal of hooks. Registering + and removing hooks should be handled by modifier classes which inherit from this + mixin, while disabling hooks should disable all hooks across modifiers. + + Modifiers which implement hooks should use the @HooksMixin.hook decorator + Modifiers must pass registered hooks handles to self.register_hook() and must + remove hooks when finished using self.remove_hooks() + """ HOOKS_DISABLED: bool = False @classmethod @@ -21,6 +30,10 @@ def wrapped(*args, **kwargs): @classmethod @contextlib.contextmanager def disable_hooks(cls): + """ + Disable all hooks across all modifiers + TODO: select which modifier hooks are disabled/ kept enabled + """ try: cls.HOOKS_DISABLED = True yield @@ -31,8 +44,16 @@ def __init__(self): self._hooks = [] def register_hook(self, handle: torch.utils.hooks.RemovableHandle): + """ + Usage: self.register_hook(module.register_forward_hook(...)) + + :param handle: handle of added hook + """ self._hooks.append(handle) def remove_hooks(self): + """ + Remove all hooks belonging to a modifier + """ for hook in self._hooks: hook.remove() From 9081f12f0239860344fe0fd5d86988db3cd88ecb Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 23 Oct 2024 01:18:33 -0400 Subject: [PATCH 044/285] fix docstring, better arguments grouping --- .../modifiers/quantization/gptq/base.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 2c31d0a70..15f16727e 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -77,8 +77,8 @@ class GPTQModifier(Modifier): through the forward pass. Set to True to use the weight-compressed outputs of each module, set to False to use the weight-compressed outputs of each layer (transformer block), defaults to False - :param targets: list of layer names to compress during GPTQ, or '__ALL__' - to compress every layer in the model + :param sequential_targets: list of layer names to compress during GPTQ, or + '__ALL__' to compress every layer in the model :param block_size: Used to determine number of columns to compress in one pass :param quantize: Set to True to quantize using an existing quantization modifier, or pass in the configuration for a quantization modifier if one does not @@ -108,16 +108,18 @@ class GPTQModifier(Modifier): sequential_update: bool = True # DEPRECIATED true_sequential: bool = False - targets: Union[str, List[str], None] = None sequential_targets: Union[str, List[str], None] = None block_size: int = 128 - quantize: Union[bool, Dict] = True dampening_frac: Optional[float] = 0.01 + quantize: Union[bool, Dict] = True + + # arguments used for quant modifier config_groups: Optional[Dict[str, QuantizationScheme]] = None + scheme: Optional[Union[str, Dict[str, Any]]] = None + targets: Union[str, List[str], None] = None ignore: List[str] = Field(default_factory=list) - disable_quantization_observer_epoch: Optional[float] = None num_calibration_steps: Optional[int] = None - scheme: Optional[Union[str, Dict[str, Any]]] = None + disable_quantization_observer_epoch: Optional[float] = None _quantization_modifier: Optional[QuantizationModifier] = PrivateAttr() _layer_compressor: Optional[SequentialLayerCompressor] = PrivateAttr() From 96e9496f266059494e0e88fd190606bfc5835ccd Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 24 Oct 2024 03:53:00 +0000 Subject: [PATCH 045/285] use LayerCompressorMixin --- .../modifiers/quantization/gptq/base.py | 23 ++- src/llmcompressor/modifiers/utils/hooks.py | 148 ++++++++++++++++-- .../modifiers/utils/layer_compressor.py | 127 +-------------- 3 files changed, 148 insertions(+), 150 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 15f16727e..177db29e4 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -20,7 +20,7 @@ quantize_weight, ) from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier -from llmcompressor.modifiers.utils.layer_compressor import SequentialLayerCompressor +from llmcompressor.modifiers.utils.hooks import LayerCompressorMixin from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.transformers.finetune.data.data_helpers import ( create_single_batch_dataloader, @@ -35,7 +35,7 @@ __all__ = ["GPTQModifier"] -class GPTQModifier(Modifier): +class GPTQModifier(Modifier, LayerCompressorMixin): """ Modifier for applying the one-shot OBCQ algorithm to a model @@ -122,7 +122,6 @@ class GPTQModifier(Modifier): disable_quantization_observer_epoch: Optional[float] = None _quantization_modifier: Optional[QuantizationModifier] = PrivateAttr() - _layer_compressor: Optional[SequentialLayerCompressor] = PrivateAttr() @field_validator("sequential_update", mode="before") def validate_sequential_update(cls, value: bool) -> bool: @@ -135,13 +134,6 @@ def validate_sequential_update(cls, value: bool) -> bool: return True - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self._layer_compressor = SequentialLayerCompressor( - self.quantize_module, self.true_sequential - ) - def on_initialize_structure(self, state: State, **kwargs): """ Check the model's quantization state matches that expected by this modifier, @@ -210,7 +202,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: # add hooks to targets and layers # after lifecycle refactor, move this to pre_batch - self._layer_compressor.register_hooks(state.model, self.sequential_targets) + self.register_hooks(state.model) # apply calibration and trigger hooks self.calibration_forward(state.model, state.data.calib) @@ -230,7 +222,7 @@ def on_finalize(self, state: "State", **kwargs) -> bool: if self._quantization_modifier: self._quantization_modifier.finalize(state, **kwargs) - self._layer_compressor.remove_hooks() + self.remove_hooks() return True @@ -248,8 +240,11 @@ def calibration_forward( with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) - def quantize_module( - self, name: str, module: torch.nn.Module, args: Tuple[torch.Tensor, ...] + def compress_module( + self, + name: str, + module: torch.nn.Module, + args: Tuple[torch.Tensor, ...], ) -> float: """ Quantize a module's weight according to the GPTQ algorithm diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 19c9a34ce..d65e41d01 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -1,12 +1,22 @@ import contextlib +from abc import abstractmethod +from functools import partial +from typing import Any, Callable, ClassVar, Dict, List, Tuple import torch +from loguru import logger +from pydantic import BaseModel +from torch.utils.hooks import RemovableHandle -__all__ = ["HooksMixin"] +from llmcompressor.utils.helpers import getattr_chain +from llmcompressor.utils.metric_logging import CompressionLogger +from llmcompressor.utils.pytorch.module import get_layers, get_no_split_params +__all__ = ["HooksMixin", "LayerCompressorMixin"] -class HooksMixin: - """" + +class HooksMixin(BaseModel): + """ " Class to manage the registration, disabling, and removal of hooks. Registering and removing hooks should be handled by modifier classes which inherit from this mixin, while disabling hooks should disable all hooks across modifiers. @@ -15,12 +25,14 @@ class HooksMixin: Modifiers must pass registered hooks handles to self.register_hook() and must remove hooks when finished using self.remove_hooks() """ - HOOKS_DISABLED: bool = False + + _HOOKS_DISABLED: ClassVar[bool] = False + _hooks: List[RemovableHandle] = [] @classmethod - def hook(cls, func): + def hook(cls, func: Callable[[Any], Any]): def wrapped(*args, **kwargs): - if cls.HOOKS_DISABLED: + if cls._HOOKS_DISABLED: return func(*args, **kwargs) @@ -35,15 +47,12 @@ def disable_hooks(cls): TODO: select which modifier hooks are disabled/ kept enabled """ try: - cls.HOOKS_DISABLED = True + cls._HOOKS_DISABLED = True yield finally: - cls.HOOKS_DISABLED = False - - def __init__(self): - self._hooks = [] + cls._HOOKS_DISABLED = False - def register_hook(self, handle: torch.utils.hooks.RemovableHandle): + def register_hook(self, handle: RemovableHandle): """ Usage: self.register_hook(module.register_forward_hook(...)) @@ -57,3 +66,118 @@ def remove_hooks(self): """ for hook in self._hooks: hook.remove() + + +class LayerCompressorMixin(HooksMixin): + """ + Apply a given compression function to a model during the model's calibration + forward pass + + Lifecycle: + - QuantizationModifier.initialize(model) + - SequentialLayerCompressor(compress_fn) + - register_hooks(model) + - model.forward() + - compress_fn(name, target_module, args) + - remove_hooks() + + :ivar true_sequential: Used to control the granularity of compression updates + through the forward pass. Set to True to use the weight-compressed outputs + of each module, set to False to use the weight-compressed outputs of each + layer (transformer block), defaults to False + :ivar sequential_targets: list of layer names to compress during GPTQ, or + '__ALL__' to compress every layer in the model + :ivar compresss_module: Function to be called on target modules + """ + + true_sequential: bool + sequential_targets: bool + # compress_module: Callable[[str, torch.nn.Module, Tuple], float] + + _layer_index = 0 + _num_layers = 0 + + @abstractmethod + def compress_module( + self, + name: str, + module: torch.nn.Module, + args: Tuple[torch.Tensor, ...], + ) -> float: + raise NotImplementedError() + + def register_hooks(self, model: torch.nn.Module): + # find layers (used for printing even if true_sequential=True) + # if no targets are provided, default to the modules that shouldn't be + # split by FSDP. For Transformers models this is equivalent to the + # decoder layers (ie LlamaDecoderLayer) + sequential_targets = self.sequential_targets + if sequential_targets is None: + sequential_targets = get_no_split_params(model) + layers = get_layers(sequential_targets, model) + self._num_layers = len(layers) + + for name, module in model.named_modules(): + if getattr_chain(module, "quantization_scheme.weights", None) is not None: + pre_hook = partial(self.target_pre_forward, name) + post_hook = partial(self.target_post_forward, name) + self.register_hook(module.register_forward_pre_hook(pre_hook)) + self.register_hook(module.register_forward_hook(post_hook)) + + if name in layers.keys(): + pre_hook = partial(self.layer_pre_forward, name) + post_hook = partial(self.layer_post_forward, name) + self.register_hook(module.register_forward_pre_hook(pre_hook)) + self.register_hook( + module.register_forward_hook(post_hook, with_kwargs=True) + ) + + @HooksMixin.hook + def target_pre_forward( + self, name: str, module: torch.nn.Module, args: Tuple[Any, ...] + ): + if self.true_sequential: + # compress first so output is from compressed weights + with CompressionLogger(module) as comp_logger: + loss = self.compress_module(name, module, args) + comp_logger.set_loss(loss) + + @HooksMixin.hook + def target_post_forward( + self, + name: str, + module: torch.nn.Module, + args: Tuple[Any, ...], + _output: Tuple[Any, ...], + ): + if not self.true_sequential: + # compress after so output is from uncompressed weights + with CompressionLogger(module) as comp_logger: + loss = self.compress_module(name, module, args) + comp_logger.set_loss(loss) + + @HooksMixin.hook + def layer_pre_forward(self, _name: str, _module: torch.nn.Module, _args: Any): + logger.info( + f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====" + ) + + @HooksMixin.hook + def layer_post_forward( + self, + _name: str, + module: torch.nn.Module, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + output: Tuple[Any, ...], + ): + if not self.true_sequential: + # rerun with (now) compressed weights + with HooksMixin.disable_hooks(): + compressed_output = module(*args, **kwargs) + + error = torch.nn.functional.l1_loss(output[0], compressed_output[0]) + logger.info(f"Mean output error from quantization: {error:.3f}") + + self._layer_index += 1 + return output diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index b168e2534..3f3aa3d02 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -1,6 +1,5 @@ import operator -from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Dict, Tuple import torch from compressed_tensors import get_execution_device @@ -8,135 +7,15 @@ from tqdm import tqdm from llmcompressor.modifiers.utils.compression_wrapper import ModuleCompressionWrapper -from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException from llmcompressor.pytorch.utils import tensors_to_device from llmcompressor.utils.fsdp.context import ( fix_fsdp_module_name, summon_full_params_context, ) -from llmcompressor.utils.helpers import getattr_chain -from llmcompressor.utils.metric_logging import CompressionLogger -from llmcompressor.utils.pytorch.module import ( - get_layers, - get_no_split_params, - get_prunable_layers, - set_layer, -) - -__all__ = ["SequentialLayerCompressor", "LayerCompressor"] - - -class SequentialLayerCompressor(HooksMixin): - """ - Apply a given compression function to a model during the model's calibration - forward pass - - Lifecycle: - - QuantizationModifier.initialize(model) - - SequentialLayerCompressor(compress_fn) - - register_hooks(model) - - model.forward() - - compress_fn(name, target_module, args) - - remove_hooks() - - :param compress_fn: Function to be called on target modules - :param true_sequential: Used to control the granularity of compression updates - through the forward pass. Set to True to use the weight-compressed outputs - of each module, set to False to use the weight-compressed outputs of each - layer (transformer block), defaults to False - """ - - def __init__( - self, - compress_fn: Callable[[str, torch.nn.Module, torch.Tensor], float], - true_sequential: bool = False, - ): - HooksMixin.__init__(self) - self.compress_fn = compress_fn - self.true_sequential = true_sequential - - self._layer_index = 0 - self._num_layers = 0 - - def register_hooks( - self, - model: torch.nn.Module, - sequential_targets: Optional[Union[str, List[str]]] = None, - ): - # find layers (used for printing even if true_sequential=True) - # if no targets are provided, default to the modules that shouldn't be - # split by FSDP. For Transformers models this is equivalent to the - # decoder layers (ie LlamaDecoderLayer) - if sequential_targets is None: - sequential_targets = get_no_split_params(model) - layers = get_layers(sequential_targets, model) - self._num_layers = len(layers) - - for name, module in model.named_modules(): - if getattr_chain(module, "quantization_scheme.weights", None) is not None: - pre_hook = partial(self.target_pre_forward, name) - post_hook = partial(self.target_post_forward, name) - self.register_hook(module.register_forward_pre_hook(pre_hook)) - self.register_hook(module.register_forward_hook(post_hook)) - - if name in layers.keys(): - pre_hook = partial(self.layer_pre_forward, name) - post_hook = partial(self.layer_post_forward, name) - self.register_hook(module.register_forward_pre_hook(pre_hook)) - self.register_hook( - module.register_forward_hook(post_hook, with_kwargs=True) - ) - - @HooksMixin.hook - def target_pre_forward( - self, name: str, module: torch.nn.Module, args: Tuple[Any, ...] - ): - if self.true_sequential: - # compress first so output is from compressed weights - with CompressionLogger(module) as comp_logger: - loss = self.compress_fn(name, module, args) - comp_logger.set_loss(loss) - - @HooksMixin.hook - def target_post_forward( - self, - name: str, - module: torch.nn.Module, - args: Tuple[Any, ...], - _output: Tuple[Any, ...], - ): - if not self.true_sequential: - # compress after so output is from uncompressed weights - with CompressionLogger(module) as comp_logger: - loss = self.compress_fn(name, module, args) - comp_logger.set_loss(loss) - - @HooksMixin.hook - def layer_pre_forward(self, _name: str, _module: torch.nn.Module, _args: Any): - logger.info( - f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====" - ) - - @HooksMixin.hook - def layer_post_forward( - self, - name: str, - module: torch.nn.Module, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - output: Tuple[Any, ...], - ): - if not self.true_sequential: - # rerun with (now) compressed weights - with HooksMixin.disable_hooks(): - compressed_output = module(*args, **kwargs) - - error = torch.nn.functional.l1_loss(output[0], compressed_output[0]) - logger.info(f"Mean output error from quantization: {error:.3f}") +from llmcompressor.utils.pytorch.module import get_prunable_layers, set_layer - self._layer_index += 1 - return output +__all__ = ["LayerCompressor"] class LayerCompressor: From 7fbf8b193f193047a8d951fb4552a41799c327e1 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 24 Oct 2024 03:56:21 +0000 Subject: [PATCH 046/285] docstrings --- src/llmcompressor/modifiers/utils/hooks.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index d65e41d01..f7242124c 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -24,6 +24,11 @@ class HooksMixin(BaseModel): Modifiers which implement hooks should use the @HooksMixin.hook decorator Modifiers must pass registered hooks handles to self.register_hook() and must remove hooks when finished using self.remove_hooks() + + Lifecycle: + - Modifier.register_hooks(model) + - model.forward() + - Modifier.remove_hooks() """ _HOOKS_DISABLED: ClassVar[bool] = False @@ -75,11 +80,10 @@ class LayerCompressorMixin(HooksMixin): Lifecycle: - QuantizationModifier.initialize(model) - - SequentialLayerCompressor(compress_fn) - - register_hooks(model) + - Modifier.register_hooks(model) - model.forward() - compress_fn(name, target_module, args) - - remove_hooks() + - Modifier.remove_hooks() :ivar true_sequential: Used to control the granularity of compression updates through the forward pass. Set to True to use the weight-compressed outputs From 3d3af2ad0d7dd3b4d63176f9867aeb97f4a8cafd Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 24 Oct 2024 04:49:55 +0000 Subject: [PATCH 047/285] add back hessian hook to support bs1 --- .../modifiers/quantization/gptq/base.py | 76 ++++++++++----- .../quantization/gptq/utils/gptq_quantize.py | 28 +++++- src/llmcompressor/utils/fsdp/helpers.py | 93 +++++++++++++++++++ 3 files changed, 170 insertions(+), 27 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 177db29e4..ed9e80269 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -17,6 +17,7 @@ from llmcompressor.core import State from llmcompressor.modifiers import Modifier, ModifierFactory from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import ( + add_batch, quantize_weight, ) from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier @@ -25,6 +26,7 @@ from llmcompressor.transformers.finetune.data.data_helpers import ( create_single_batch_dataloader, ) +from llmcompressor.utils.fsdp.helpers import has_offloaded_params, register_offload_parameter from llmcompressor.utils.helpers import ( align_module, calibration_forward_context, @@ -200,15 +202,17 @@ def on_initialize(self, state: "State", **kwargs) -> bool: if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") - # add hooks to targets and layers - # after lifecycle refactor, move this to pre_batch - self.register_hooks(state.model) + # trigger hessian hooks + self.register_hessians(state.model) + with calibration_forward_context(state.model): + run_calibration_forward(state.model, state.data.calib, mask_padding=True) + self.remove_hooks() - # apply calibration and trigger hooks - self.calibration_forward(state.model, state.data.calib) + self.register_hooks(state.model) + state.model(**state.model.dummy_inputs) + self.remove_hooks() # freeze quantization - # after lifecycle refactor, move this to post_batch state.model.apply(freeze_module_quantization) return True @@ -222,9 +226,31 @@ def on_finalize(self, state: "State", **kwargs) -> bool: if self._quantization_modifier: self._quantization_modifier.finalize(state, **kwargs) - self.remove_hooks() - return True + + def hessian_hook(self, module, args): + # onload and offload + module.gptq_hessian = add_batch( + module.gptq_hessian.to(args[0].device), + module.gptq_hessian_samples, + module, + args[0] + ).to("cpu") + module.gptq_hessian_samples += 1 + + def register_hessians(self, model: torch.nn.Module): + for module in model.modules(): + if getattr_chain(module, "quantization_scheme.weights", None) is not None: + num_columns = module.weight.shape[1] + + # hessian starts offloaded + module.gptq_hessian = torch.zeros((num_columns, num_columns), dtype=torch.float32, device="cpu") + module.gptq_hessian_samples = 0 + + self.register_hook(module.register_forward_pre_hook(self.hessian_hook)) + + + def calibration_forward( self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader @@ -261,24 +287,26 @@ def compress_module( inp = args[0] quant_args = getattr_chain(module, "quantization_scheme.weights") - with align_module(module): - loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( - module.weight.data, - inp, - quant_args, - blocksize=self.block_size, - percdamp=self.dampening_frac, - module_class=type(module), - ) + loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( + module.weight.data, + module.gptq_hessian.data.to(module.weight.device), + quant_args, + blocksize=self.block_size, + percdamp=self.dampening_frac, + module_class=type(module), + ) + + delattr(module, "gptq_hessian") + delattr(module, "gptq_hessian_samples") - # FUTURE: Implement learning rate modification to weight update + # FUTURE: Implement learning rate modification to weight update - if is_module_offloaded(module): - update_prefix_dict(self.layer, "weight", quantized_weight) - update_parameter_data(module, quantized_weight, "weight") - update_parameter_data(module, scale, "weight_scale") - update_parameter_data(module, zero_point, "weight_zero_point") - update_parameter_data(module, g_idx, "weight_g_idx") + if is_module_offloaded(module): + update_prefix_dict(self.layer, "weight", quantized_weight) + update_parameter_data(module, quantized_weight, "weight") + update_parameter_data(module, scale, "weight_scale") + update_parameter_data(module, zero_point, "weight_zero_point") + update_parameter_data(module, g_idx, "weight_g_idx") return loss diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 022252b0a..1ee435d4d 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -17,6 +17,28 @@ GPTQ_PRECISION = torch.float32 +def add_batch(H: torch.Tensor, nsamples: int , module: torch.nn.Module, inp: torch.Tensor): + """ + Add a batch of layer input and output data to the Hessian calculation + """ + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if isinstance(module, torch.nn.Linear) or isinstance( + module, transformers.Conv1D + ): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + H *= nsamples / (nsamples + tmp) + nsamples += tmp + inp = inp.to(dtype=H.dtype) + inp = math.sqrt(2 / nsamples) * inp + H += inp.matmul(inp.t()) + + return H + + def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: """ Calculate the hessian with respect to the module inputs @@ -81,7 +103,7 @@ def compute_scale_zero_point( def quantize_weight( weight: torch.Tensor, - inp: torch.Tensor, + H: torch.Tensor, #inp: torch.Tensor, quant_args: QuantizationArgs, blocksize: int = 128, percdamp: float = 0.01, @@ -91,7 +113,7 @@ def quantize_weight( Quantize a module weight according to the GPTQ algorithm :param weight: weight being quantized - :param inp: module inputs used to calculate hessian + # :param inp: module inputs used to calculate hessian :param quant_args: quantization arguments used to find quantization parameters :param blocksize: chunk size of quantization updates :param percdamp: dampening factor on hessian diagonal @@ -104,7 +126,7 @@ def quantize_weight( final_dtype = weight.dtype W = weight.data.clone() - H = compute_hessian(inp, module_class, device=weight.device) + #H = compute_hessian(inp, module_class, device=weight.device) # standardize shape and dtype if module_class == torch.nn.Conv2d: diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py index 8cc0f5405..80ef733f1 100644 --- a/src/llmcompressor/utils/fsdp/helpers.py +++ b/src/llmcompressor/utils/fsdp/helpers.py @@ -1,9 +1,12 @@ +import contextlib import operator from pathlib import Path from typing import Optional from loguru import logger +from llmcompressor.utils.helpers import getattr_chain + try: from torch.distributed.fsdp import ( FullStateDictConfig, @@ -179,3 +182,93 @@ def get_fsdp_parent(layer_name: str, model: Module) -> Optional[Module]: parent = operator.attrgetter(parent_name)(model) return parent + +def has_offloaded_params(module: torch.nn.Module) -> bool: + """ + Checks if a module has offloaded parameters by checking if the given module + has a AlignDevicesHook attached with offloading enabled + Args: + module (`torch.nn.Module`): The module to check for an offload hook. + Returns: + bool: `True` if the module has an offload hook and offloading is enabled, + `False` otherwise. + """ + from accelerate.hooks import AlignDevicesHook + + return ( + hasattr(module, "_hf_hook") and + isinstance(module._hf_hook, AlignDevicesHook) and + module._hf_hook.offload + ) + +@contextlib.contextmanager +def align_module( + module: torch.nn.Module, + execution_device: Optional[torch.device] = None, + args = tuple(), kwargs = dict() +): + """ + Move a module's parameters to the execution device + :param module: module with parameters to align + :param execution_device: if provided, overrides module execution device + within the context + """ + if has_offloaded_params(module): + if execution_device is not None: + original_device = module._hf_hook.execution_device + module._hf_hook.execution_device = original_device + + module._hf_hook.pre_forward(module, *args, **kwargs) + yield + module._hf_hook.post_forward(module, None) + + if execution_device is not None: + module._hf_hook.execution_device = original_device + + elif execution_device is not None: + devices = {} + for name, param in module.named_parameters(): + devices[name] = param.device + setattr(module, name, param.to(execution_device)) + + yield + + for name, param_device in module.named_parameters: + setattr(module, name, param.to(param_device)) + + else: + yield + + +def update_offload_parameter( + module: torch.nn.Module, + name: str, + data: torch.Tensor, + init_device: Optional[torch.device] = torch.device("cpu"), +): + """ + :param module: module containing the parameter to update + :param name: name of module parameter to update + :param data: tensor to update parameter with + :param init_device: offload device for newly registered parameters + """ + param = getattr(module, name) + param.data = data + + prefix_dict = getattr_chain(module, "module._hf_hook.weights_map.dataset", None) + if prefix_dict is not None: + prefix = module._hf_hook.weights_map.prefix + key = f"{prefix}{name}" + + offload_device = prefix_dict[key].device if key in prefix_dict else init_device + prefix_dict[key] = data.to(device=offload_device) + + +def register_offload_parameter( + module: torch.nn.Module, + name: str, + data: torch.Tensor, + offload_device: Optional[torch.device] = torch.device("cpu"), +): + module.register_parameter(name, torch.nn.Parameter(data)) + update_offload_parameter(module, name, data, offload_device) \ No newline at end of file From b3021ab9e8d30aeb62786a11e464fcbb1ec5898f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 25 Oct 2024 16:11:33 +0000 Subject: [PATCH 048/285] wip --- .../modifiers/quantization/gptq/base.py | 30 +------- src/llmcompressor/modifiers/utils/hooks.py | 69 +++++++++++++++++-- 2 files changed, 65 insertions(+), 34 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index ed9e80269..d18e74249 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -203,13 +203,11 @@ def on_initialize(self, state: "State", **kwargs) -> bool: raise ValueError("To use the GPTQModifier, quantization must be enabled.") # trigger hessian hooks - self.register_hessians(state.model) + self.register_hooks(state.model) with calibration_forward_context(state.model): run_calibration_forward(state.model, state.data.calib, mask_padding=True) - self.remove_hooks() - self.register_hooks(state.model) - state.model(**state.model.dummy_inputs) + #state.model(**state.model.dummy_inputs) self.remove_hooks() # freeze quantization @@ -227,30 +225,6 @@ def on_finalize(self, state: "State", **kwargs) -> bool: self._quantization_modifier.finalize(state, **kwargs) return True - - def hessian_hook(self, module, args): - # onload and offload - module.gptq_hessian = add_batch( - module.gptq_hessian.to(args[0].device), - module.gptq_hessian_samples, - module, - args[0] - ).to("cpu") - module.gptq_hessian_samples += 1 - - def register_hessians(self, model: torch.nn.Module): - for module in model.modules(): - if getattr_chain(module, "quantization_scheme.weights", None) is not None: - num_columns = module.weight.shape[1] - - # hessian starts offloaded - module.gptq_hessian = torch.zeros((num_columns, num_columns), dtype=torch.float32, device="cpu") - module.gptq_hessian_samples = 0 - - self.register_hook(module.register_forward_pre_hook(self.hessian_hook)) - - - def calibration_forward( self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index f7242124c..da1094872 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -1,13 +1,15 @@ import contextlib from abc import abstractmethod from functools import partial -from typing import Any, Callable, ClassVar, Dict, List, Tuple +from typing import Any, Callable, ClassVar, Dict, List, Set, Tuple import torch from loguru import logger from pydantic import BaseModel from torch.utils.hooks import RemovableHandle +from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import add_batch +from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException from llmcompressor.utils.helpers import getattr_chain from llmcompressor.utils.metric_logging import CompressionLogger from llmcompressor.utils.pytorch.module import get_layers, get_no_split_params @@ -100,6 +102,9 @@ class LayerCompressorMixin(HooksMixin): _layer_index = 0 _num_layers = 0 + _pre_active: Set[torch.nn.Module] = set() + _module_inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [] + _module_outputs: List[Tuple[Any, ...]] = [] @abstractmethod def compress_module( @@ -125,7 +130,7 @@ def register_hooks(self, model: torch.nn.Module): if getattr_chain(module, "quantization_scheme.weights", None) is not None: pre_hook = partial(self.target_pre_forward, name) post_hook = partial(self.target_post_forward, name) - self.register_hook(module.register_forward_pre_hook(pre_hook)) + self.register_hook(module.register_forward_pre_hook(pre_hook, with_kwargs=True)) self.register_hook(module.register_forward_hook(post_hook)) if name in layers.keys(): @@ -138,22 +143,74 @@ def register_hooks(self, model: torch.nn.Module): @HooksMixin.hook def target_pre_forward( - self, name: str, module: torch.nn.Module, args: Tuple[Any, ...] + self, name: str, module: torch.nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any] ): - if self.true_sequential: - # compress first so output is from compressed weights + if module in self._pre_active: + return + + if not hasattr(module, "gptq_hessian"): + print("init hessian") + num_columns = module.weight.shape[1] + + # hessian starts offloaded + module.gptq_hessian = torch.zeros((num_columns, num_columns), dtype=torch.float32, device="cpu") + module.gptq_hessian_samples = 0 + + print("add to hessian") + # onload and offload + module.gptq_hessian = add_batch( + module.gptq_hessian.to(args[0].device), + module.gptq_hessian_samples, + module, + args[0] + ).to("cpu") + module.gptq_hessian_samples += 1 + self._module_inputs.append((args, kwargs)) + + if module.gptq_hessian_samples >= 2: + print("compress") with CompressionLogger(module) as comp_logger: loss = self.compress_module(name, module, args) comp_logger.set_loss(loss) + self._pre_active.add(module) + for args, kwargs in self._module_inputs: + try: + module(*args, **kwargs) + except EarlyStopException: + pass + + raise EarlyStopException(torch.Tensor([]), None) + @HooksMixin.hook def target_post_forward( self, name: str, module: torch.nn.Module, args: Tuple[Any, ...], - _output: Tuple[Any, ...], + output: Tuple[Any, ...], ): + print("target_post_forward") + return + # accumulate + self._module_outputs.append(output) + + if len(self._module_outputs) == 2: + with CompressionLogger(module) as comp_logger: + loss = self.compress_module(name, module, args) + comp_logger.set_loss(loss) + + ret = self._module_outputs + self._module_outputs = [] + + return ret + + if self.true_sequential: + # compress first so output is from compressed weights + with CompressionLogger(module) as comp_logger: + loss = self.compress_module(name, module, args) + comp_logger.set_loss(loss) + if not self.true_sequential: # compress after so output is from uncompressed weights with CompressionLogger(module) as comp_logger: From 8508b633f14d03e92f762a1fc98818b09ffefd98 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 25 Oct 2024 18:37:42 +0000 Subject: [PATCH 049/285] accumulate --- .../modifiers/quantization/gptq/base.py | 4 +- .../quantization/gptq/utils/gptq_quantize.py | 2 +- src/llmcompressor/modifiers/utils/hooks.py | 72 +++++++++++-------- 3 files changed, 43 insertions(+), 35 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index d18e74249..68d603a4b 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -202,10 +202,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool: if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") - # trigger hessian hooks self.register_hooks(state.model) - with calibration_forward_context(state.model): - run_calibration_forward(state.model, state.data.calib, mask_padding=True) + self.calibration_forward(state.model, state.data.calib) #state.model(**state.model.dummy_inputs) self.remove_hooks() diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 1ee435d4d..4365ce3d3 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -36,7 +36,7 @@ def add_batch(H: torch.Tensor, nsamples: int , module: torch.nn.Module, inp: tor inp = math.sqrt(2 / nsamples) * inp H += inp.matmul(inp.t()) - return H + return H, nsamples def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index da1094872..83f50d6b1 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -1,7 +1,7 @@ import contextlib from abc import abstractmethod from functools import partial -from typing import Any, Callable, ClassVar, Dict, List, Set, Tuple +from typing import Any, Callable, ClassVar, Dict, List, Set, Tuple, Union import torch from loguru import logger @@ -104,7 +104,10 @@ class LayerCompressorMixin(HooksMixin): _num_layers = 0 _pre_active: Set[torch.nn.Module] = set() _module_inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [] - _module_outputs: List[Tuple[Any, ...]] = [] + _module_outputs: Union[List[Tuple[Any, ...]], torch.Tensor] = [] + + _layer_inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [] + _layer_outputs: List[Tuple[Any, ...]] = [] @abstractmethod def compress_module( @@ -136,7 +139,7 @@ def register_hooks(self, model: torch.nn.Module): if name in layers.keys(): pre_hook = partial(self.layer_pre_forward, name) post_hook = partial(self.layer_post_forward, name) - self.register_hook(module.register_forward_pre_hook(pre_hook)) + self.register_hook(module.register_forward_pre_hook(pre_hook, with_kwargs=True)) self.register_hook( module.register_forward_hook(post_hook, with_kwargs=True) ) @@ -145,42 +148,42 @@ def register_hooks(self, model: torch.nn.Module): def target_pre_forward( self, name: str, module: torch.nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any] ): - if module in self._pre_active: - return + input = args[0] + # compute hessian if not hasattr(module, "gptq_hessian"): - print("init hessian") - num_columns = module.weight.shape[1] - # hessian starts offloaded - module.gptq_hessian = torch.zeros((num_columns, num_columns), dtype=torch.float32, device="cpu") + num_columns = module.weight.shape[1] + module.gptq_hessian = torch.zeros((num_columns, num_columns), dtype=torch.float32, device=input.device) module.gptq_hessian_samples = 0 - print("add to hessian") - # onload and offload - module.gptq_hessian = add_batch( - module.gptq_hessian.to(args[0].device), + module.gptq_hessian, module.gptq_hessian_samples = add_batch( + module.gptq_hessian, module.gptq_hessian_samples, module, - args[0] - ).to("cpu") - module.gptq_hessian_samples += 1 - self._module_inputs.append((args, kwargs)) - + input + ) + if module.gptq_hessian_samples >= 2: - print("compress") - with CompressionLogger(module) as comp_logger: - loss = self.compress_module(name, module, args) - comp_logger.set_loss(loss) + # if true, compress + if True: #self.true_sequential: + with CompressionLogger(module) as comp_logger: + loss = self.compress_module(name, module, args) + comp_logger.set_loss(loss) - self._pre_active.add(module) - for args, kwargs in self._module_inputs: - try: - module(*args, **kwargs) - except EarlyStopException: - pass + else: + raise EarlyStopException(torch.Tensor([]), None) - raise EarlyStopException(torch.Tensor([]), None) + # forward with individuals + forward_call = (module._slow_forward if torch._C._get_tracing_state() else module.forward) + self._module_outputs = [ + forward_call(input[batch_index: batch_index + 1]) + for batch_index in range(input.shape[0]) + ] + + self._module_outputs = torch.concat(self._module_outputs) + + return (input[0:1], *args[1:]), kwargs @HooksMixin.hook def target_post_forward( @@ -191,7 +194,11 @@ def target_post_forward( output: Tuple[Any, ...], ): print("target_post_forward") - return + + ret = self._module_outputs + self._module_outputs = [] + return ret + # accumulate self._module_outputs.append(output) @@ -218,11 +225,14 @@ def target_post_forward( comp_logger.set_loss(loss) @HooksMixin.hook - def layer_pre_forward(self, _name: str, _module: torch.nn.Module, _args: Any): + def layer_pre_forward(self, _name: str, layer: torch.nn.Module, _args: Any, kwargs): logger.info( f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====" ) + + + @HooksMixin.hook def layer_post_forward( self, From 3ff271d87fec32b995c5d76d409abefd3712c388 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 25 Oct 2024 21:16:21 +0000 Subject: [PATCH 050/285] virtualize batches for layers --- src/llmcompressor/modifiers/utils/hooks.py | 98 +++++++++++++++------- 1 file changed, 68 insertions(+), 30 deletions(-) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 83f50d6b1..d5331f640 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -7,6 +7,7 @@ from loguru import logger from pydantic import BaseModel from torch.utils.hooks import RemovableHandle +from collections import defaultdict from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import add_batch from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException @@ -42,7 +43,7 @@ def wrapped(*args, **kwargs): if cls._HOOKS_DISABLED: return - func(*args, **kwargs) + return func(*args, **kwargs) return wrapped @@ -103,8 +104,8 @@ class LayerCompressorMixin(HooksMixin): _layer_index = 0 _num_layers = 0 _pre_active: Set[torch.nn.Module] = set() - _module_inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [] - _module_outputs: Union[List[Tuple[Any, ...]], torch.Tensor] = [] + _module_inputs: Dict[torch.nn.Module, List[Tuple[Tuple[Any, ...], Dict[str, Any]]]] = defaultdict(lambda: []) + _module_outputs: Dict[torch.nn.Module, Union[List[Tuple[Any, ...]], torch.Tensor]] = defaultdict(lambda: []) _layer_inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [] _layer_outputs: List[Tuple[Any, ...]] = [] @@ -143,6 +144,7 @@ def register_hooks(self, model: torch.nn.Module): self.register_hook( module.register_forward_hook(post_hook, with_kwargs=True) ) + @HooksMixin.hook def target_pre_forward( @@ -152,11 +154,11 @@ def target_pre_forward( # compute hessian if not hasattr(module, "gptq_hessian"): - num_columns = module.weight.shape[1] module.gptq_hessian = torch.zeros((num_columns, num_columns), dtype=torch.float32, device=input.device) module.gptq_hessian_samples = 0 + print(f"{name} adding {input.size(0)} samples") module.gptq_hessian, module.gptq_hessian_samples = add_batch( module.gptq_hessian, module.gptq_hessian_samples, @@ -164,26 +166,6 @@ def target_pre_forward( input ) - if module.gptq_hessian_samples >= 2: - # if true, compress - if True: #self.true_sequential: - with CompressionLogger(module) as comp_logger: - loss = self.compress_module(name, module, args) - comp_logger.set_loss(loss) - - else: - raise EarlyStopException(torch.Tensor([]), None) - - # forward with individuals - forward_call = (module._slow_forward if torch._C._get_tracing_state() else module.forward) - self._module_outputs = [ - forward_call(input[batch_index: batch_index + 1]) - for batch_index in range(input.shape[0]) - ] - - self._module_outputs = torch.concat(self._module_outputs) - - return (input[0:1], *args[1:]), kwargs @HooksMixin.hook def target_post_forward( @@ -193,10 +175,21 @@ def target_post_forward( args: Tuple[Any, ...], output: Tuple[Any, ...], ): - print("target_post_forward") + print(f"post {name}") - ret = self._module_outputs - self._module_outputs = [] + if module.gptq_hessian_samples >= 512: + # compress + print(f"compressing {name}") + if True: #self.true_sequential: + with CompressionLogger(module) as comp_logger: + loss = self.compress_module(name, module, args) + comp_logger.set_loss(loss) + + """ + breakpoint() + ret = torch.concat(self._module_outputs) + del self._module_inputs[module] + del self._module_outputs[module] return ret # accumulate @@ -223,25 +216,70 @@ def target_post_forward( with CompressionLogger(module) as comp_logger: loss = self.compress_module(name, module, args) comp_logger.set_loss(loss) + """ @HooksMixin.hook - def layer_pre_forward(self, _name: str, layer: torch.nn.Module, _args: Any, kwargs): + def layer_pre_forward(self, name: str, layer: torch.nn.Module, args: Any, kwargs): logger.info( f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====" ) - + input = args[0] + + if not self.true_sequential: + self._module_inputs[layer] += [ + input[batch_index: batch_index + 1] + for batch_index in range(input.shape[0]) + ] + + # forward with individuals (might not be necessary) + forward_call = (layer._slow_forward if torch._C._get_tracing_state() else layer.forward) + self._module_outputs[layer] = [] + for batch_index in range(input.size(0) - 1): + print("layer forward") + output = forward_call(input[batch_index: batch_index + 1], *args[1:], **kwargs) + self._module_outputs[layer].append(output) + pass + + # last sample can be passed normally + print("last layer forward") + + return (input[-1:], *args[1:]), kwargs @HooksMixin.hook def layer_post_forward( self, - _name: str, + name: str, module: torch.nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any], output: Tuple[Any, ...], ): + print(f"post {name}") + breakpoint() + + # capture last sample + self._module_outputs[module].append(output) + + # batch outputs + outputs = self._module_outputs[module] + batched_outputs = tuple( + torch.concat(tuple( + outputs[sample_index][output_index] + for sample_index in range(len(outputs)) + )) + for output_index in range(len(outputs[0])) + ) + del self._module_outputs[module] + + if not self.true_sequential: + pass # run again + + del self._module_inputs[module] + + return batched_outputs + if not self.true_sequential: # rerun with (now) compressed weights with HooksMixin.disable_hooks(): From d6c6dc339381cf5eb893e6134604b12a25fc6127 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 25 Oct 2024 22:02:10 +0000 Subject: [PATCH 051/285] maybe works, but padding is wrong --- .../modifiers/quantization/gptq/base.py | 2 +- src/llmcompressor/modifiers/utils/hooks.py | 37 ++++++++++--------- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 68d603a4b..767664640 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -234,7 +234,7 @@ def calibration_forward( :param model: model to perform forward pass with :param dataloader: dataloader containing calibration dataset """ - dataloader = create_single_batch_dataloader(dataloader.dataset) + #dataloader = create_single_batch_dataloader(dataloader.dataset) with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index d5331f640..322f7b787 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -177,7 +177,7 @@ def target_post_forward( ): print(f"post {name}") - if module.gptq_hessian_samples >= 512: + if module.gptq_hessian_samples >= 20: # compress print(f"compressing {name}") if True: #self.true_sequential: @@ -232,26 +232,27 @@ def layer_pre_forward(self, name: str, layer: torch.nn.Module, args: Any, kwargs for batch_index in range(input.shape[0]) ] - # forward with individuals (might not be necessary) - forward_call = (layer._slow_forward if torch._C._get_tracing_state() else layer.forward) - self._module_outputs[layer] = [] - for batch_index in range(input.size(0) - 1): - print("layer forward") - output = forward_call(input[batch_index: batch_index + 1], *args[1:], **kwargs) - self._module_outputs[layer].append(output) - pass - # last sample can be passed normally - print("last layer forward") + if len(self._module_outputs[layer]) >= 20 - 1: + # last sample can be passed normally + print("last layer forward") + return (input[-1:], *args[1:]), kwargs + + else: + forward_call = (layer._slow_forward if torch._C._get_tracing_state() else layer.forward) + for batch_index in range(input.size(0)): + print("layer forward") + output = forward_call(input[batch_index: batch_index + 1], *args[1:], **kwargs) + self._module_outputs[layer].append(output) - return (input[-1:], *args[1:]), kwargs + raise EarlyStopException(torch.tensor([]), None) @HooksMixin.hook def layer_post_forward( self, name: str, - module: torch.nn.Module, + layer: torch.nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any], output: Tuple[Any, ...], @@ -260,10 +261,10 @@ def layer_post_forward( breakpoint() # capture last sample - self._module_outputs[module].append(output) + self._module_outputs[layer].append(output) # batch outputs - outputs = self._module_outputs[module] + outputs = self._module_outputs[layer] batched_outputs = tuple( torch.concat(tuple( outputs[sample_index][output_index] @@ -271,19 +272,19 @@ def layer_post_forward( )) for output_index in range(len(outputs[0])) ) - del self._module_outputs[module] + del self._module_outputs[layer] if not self.true_sequential: pass # run again - del self._module_inputs[module] + del self._module_inputs[layer] return batched_outputs if not self.true_sequential: # rerun with (now) compressed weights with HooksMixin.disable_hooks(): - compressed_output = module(*args, **kwargs) + compressed_output = layer(*args, **kwargs) error = torch.nn.functional.l1_loss(output[0], compressed_output[0]) logger.info(f"Mean output error from quantization: {error:.3f}") From 400fa0864875b15bbf0ff15d1a85db40b3e9c656 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 14:54:42 +0000 Subject: [PATCH 052/285] WIP --- .../modifiers/quantization/gptq/base.py | 42 ++-- .../quantization/gptq/utils/gptq_quantize.py | 19 +- src/llmcompressor/modifiers/utils/hooks.py | 81 +++----- src/llmcompressor/utils/fsdp/helpers.py | 188 +++++++++++++++--- 4 files changed, 218 insertions(+), 112 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 767664640..7e6e5556d 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -259,26 +259,34 @@ def compress_module( inp = args[0] quant_args = getattr_chain(module, "quantization_scheme.weights") - loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( - module.weight.data, - module.gptq_hessian.data.to(module.weight.device), - quant_args, - blocksize=self.block_size, - percdamp=self.dampening_frac, - module_class=type(module), - ) + offloaded = is_module_offloaded(module) + if offloaded: + module._hf_hook.pre_forward(module) + + loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( + module.weight.data, + inp, + quant_args, + blocksize=self.block_size, + percdamp=self.dampening_frac, + module_class=type(module), + original_weight=module.original_weight.data, + ) + + delattr(module, "gptq_hessian") + delattr(module, "gptq_hessian_samples") - delattr(module, "gptq_hessian") - delattr(module, "gptq_hessian_samples") + # FUTURE: Implement learning rate modification to weight update - # FUTURE: Implement learning rate modification to weight update + if is_module_offloaded(module): + update_prefix_dict(self.layer, "weight", quantized_weight) + update_parameter_data(module, quantized_weight, "weight") + update_parameter_data(module, scale, "weight_scale") + update_parameter_data(module, zero_point, "weight_zero_point") + update_parameter_data(module, g_idx, "weight_g_idx") - if is_module_offloaded(module): - update_prefix_dict(self.layer, "weight", quantized_weight) - update_parameter_data(module, quantized_weight, "weight") - update_parameter_data(module, scale, "weight_scale") - update_parameter_data(module, zero_point, "weight_zero_point") - update_parameter_data(module, g_idx, "weight_g_idx") + if offloaded: + module._hf_hook.post_forward(module, None) return loss diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 4365ce3d3..5f4f0cd22 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -1,6 +1,6 @@ import math from copy import copy -from typing import Tuple, Union +from typing import Tuple, Union, Optional, Type import torch import transformers @@ -82,7 +82,8 @@ def invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: def compute_scale_zero_point( - W: torch.Tensor, quant_args: QuantizationArgs + W: torch.Tensor, + quant_args: QuantizationArgs, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Compute the scale and zero point of a module weight @@ -103,17 +104,19 @@ def compute_scale_zero_point( def quantize_weight( weight: torch.Tensor, - H: torch.Tensor, #inp: torch.Tensor, + inp: torch.Tensor, quant_args: QuantizationArgs, blocksize: int = 128, percdamp: float = 0.01, - module_class=torch.nn.Linear, + module_class: Type[torch.nn.Module] = torch.nn.Linear, + original_weight: Optional[torch.Tensor] = None, ) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]: """ Quantize a module weight according to the GPTQ algorithm + TODO :param weight: weight being quantized - # :param inp: module inputs used to calculate hessian + :param inp: module inputs used to calculate hessian :param quant_args: quantization arguments used to find quantization parameters :param blocksize: chunk size of quantization updates :param percdamp: dampening factor on hessian diagonal @@ -126,7 +129,7 @@ def quantize_weight( final_dtype = weight.dtype W = weight.data.clone() - #H = compute_hessian(inp, module_class, device=weight.device) + H = compute_hessian(inp, module_class, device=weight.device) # standardize shape and dtype if module_class == torch.nn.Conv2d: @@ -199,9 +202,9 @@ def quantize_weight( W1_nz_mask = W_nz_mask[:, i1:i2] for i in range(count): - w = W1[:, i] + w = original_weight[:, i] d = Hinv1[i, i] - q = w.clone() + q = W1[:, i].clone() # quantize column if strategy == QuantizationStrategy.TENSOR: diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 322f7b787..414c75bc1 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -11,6 +11,7 @@ from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import add_batch from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException +from llmcompressor.utils.fsdp.helpers import register_offload_parameter from llmcompressor.utils.helpers import getattr_chain from llmcompressor.utils.metric_logging import CompressionLogger from llmcompressor.utils.pytorch.module import get_layers, get_no_split_params @@ -99,11 +100,9 @@ class LayerCompressorMixin(HooksMixin): true_sequential: bool sequential_targets: bool - # compress_module: Callable[[str, torch.nn.Module, Tuple], float] _layer_index = 0 _num_layers = 0 - _pre_active: Set[torch.nn.Module] = set() _module_inputs: Dict[torch.nn.Module, List[Tuple[Tuple[Any, ...], Dict[str, Any]]]] = defaultdict(lambda: []) _module_outputs: Dict[torch.nn.Module, Union[List[Tuple[Any, ...]], torch.Tensor]] = defaultdict(lambda: []) @@ -137,6 +136,10 @@ def register_hooks(self, model: torch.nn.Module): self.register_hook(module.register_forward_pre_hook(pre_hook, with_kwargs=True)) self.register_hook(module.register_forward_hook(post_hook)) + #register_offload_parameter(module, "weight_original", torch.nn.Parameter(module.weight.data.clone(), requires_grad=False)) # TODO: better name? + register_offload_parameter(module, "weight_update_acc", torch.nn.Parameter(module.weight.data.clone(), requires_grad=False)) # TODO: better name? + register_offload_parameter(module, "num_samples", torch.nn.Parameter(torch.tensor(0.0), requires_grad=False)) # TODO: better name? + if name in layers.keys(): pre_hook = partial(self.layer_pre_forward, name) post_hook = partial(self.layer_post_forward, name) @@ -166,6 +169,14 @@ def target_pre_forward( input ) + if self.true_sequential: + if module.gptq_hessian_samples >= 20: + # compress + print(f"compressing {name}") + if True: #self.true_sequential: + with CompressionLogger(module) as comp_logger: + loss = self.compress_module(name, module, args) + comp_logger.set_loss(loss) @HooksMixin.hook def target_post_forward( @@ -176,14 +187,14 @@ def target_post_forward( output: Tuple[Any, ...], ): print(f"post {name}") - - if module.gptq_hessian_samples >= 20: - # compress - print(f"compressing {name}") - if True: #self.true_sequential: - with CompressionLogger(module) as comp_logger: - loss = self.compress_module(name, module, args) - comp_logger.set_loss(loss) + if not self.true_sequential: + if module.gptq_hessian_samples >= 20: + # compress + print(f"compressing {name}") + if True: #self.true_sequential: + with CompressionLogger(module) as comp_logger: + loss = self.compress_module(name, module, args) + comp_logger.set_loss(loss) """ breakpoint() @@ -223,30 +234,7 @@ def layer_pre_forward(self, name: str, layer: torch.nn.Module, args: Any, kwargs logger.info( f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====" ) - - input = args[0] - - if not self.true_sequential: - self._module_inputs[layer] += [ - input[batch_index: batch_index + 1] - for batch_index in range(input.shape[0]) - ] - - - if len(self._module_outputs[layer]) >= 20 - 1: - # last sample can be passed normally - print("last layer forward") - return (input[-1:], *args[1:]), kwargs - - else: - forward_call = (layer._slow_forward if torch._C._get_tracing_state() else layer.forward) - for batch_index in range(input.size(0)): - print("layer forward") - output = forward_call(input[batch_index: batch_index + 1], *args[1:], **kwargs) - self._module_outputs[layer].append(output) - - raise EarlyStopException(torch.tensor([]), None) - + @HooksMixin.hook def layer_post_forward( @@ -258,30 +246,9 @@ def layer_post_forward( output: Tuple[Any, ...], ): print(f"post {name}") - breakpoint() - - # capture last sample - self._module_outputs[layer].append(output) - - # batch outputs - outputs = self._module_outputs[layer] - batched_outputs = tuple( - torch.concat(tuple( - outputs[sample_index][output_index] - for sample_index in range(len(outputs)) - )) - for output_index in range(len(outputs[0])) - ) - del self._module_outputs[layer] - - if not self.true_sequential: - pass # run again - - del self._module_inputs[layer] - - return batched_outputs - if not self.true_sequential: + + if False and not self.true_sequential: # only print # rerun with (now) compressed weights with HooksMixin.disable_hooks(): compressed_output = layer(*args, **kwargs) diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py index 80ef733f1..e5ecda8c7 100644 --- a/src/llmcompressor/utils/fsdp/helpers.py +++ b/src/llmcompressor/utils/fsdp/helpers.py @@ -1,7 +1,9 @@ import contextlib +from functools import wraps import operator from pathlib import Path from typing import Optional +import warnings from loguru import logger @@ -23,6 +25,16 @@ from llmcompressor.pytorch.model_load.helpers import save_model_and_recipe from llmcompressor.utils.pytorch import set_layer +try: + from accelerate.hooks import AlignDevicesHook + from accelerate.utils import OffloadedWeightsLoader, PrefixedDataset + _has_accelerate = True +except ImportError: + _has_accelerate = False + AlignDevicesHook = None + OffloadedWeightsLoader = None + PrefixedDataset = None + __all__ = [ "is_fsdp_model", "maybe_get_wrapped", @@ -183,32 +195,150 @@ def get_fsdp_parent(layer_name: str, model: Module) -> Optional[Module]: return parent +# upstream candidate def has_offloaded_params(module: torch.nn.Module) -> bool: """ Checks if a module has offloaded parameters by checking if the given module has a AlignDevicesHook attached with offloading enabled + Args: module (`torch.nn.Module`): The module to check for an offload hook. + Returns: bool: `True` if the module has an offload hook and offloading is enabled, `False` otherwise. """ - from accelerate.hooks import AlignDevicesHook - return ( hasattr(module, "_hf_hook") and isinstance(module._hf_hook, AlignDevicesHook) and module._hf_hook.offload ) -@contextlib.contextmanager -def align_module( + +# depreciation candidate +@wraps(has_offloaded_params) +def is_module_offloaded(module: torch.nn.Module) -> bool: + if not _has_accelerate: + return False + + return has_offloaded_params(module) + + +# depreciation candidate +def get_execution_device(module: torch.nn.Module) -> torch.device: + """ + :param module: module to check + :return: device module is loaded onto during forward pass + """ + if is_module_offloaded(module): + return module._hf_hook.execution_device + device = next(module.parameters()).device + + # offload only gets set for leaf modules, fallback to checking for device type + if device.type == "meta": + return module._hf_hook.execution_device + + return device + + +# upstream candidate +def _infer_offload_device(module: torch.nn.Module) -> torch.device: + if not has_offloaded_params(module): + raise ValueError("Cannot infer offload device from non-offloaded module") + + first_key = next(module._hf_hook.weights_map.keys(), None) + if first_key is None: + raise ValueError("Cannot infer offload device from empty weights map") + + prefix_dataset = module._hf_hook.weights_map.dataset + return prefix_dataset[first_key].device + +# depreciation candidate +def get_offloaded_device(module: torch.nn.Module) -> torch.device: + """ + :param module: module to check + :return: device module is offloaded to onto after forward pass + """ + return _infer_offload_device(module) + + +# depreciation candidate +def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor): + """ + Updates the offloaded state dict for a given module. Parameter named key is replaced + by data. This is neccesary because parameter updates for offloaded modules do not + persist automatically between loads. This function only affects the offloaded + state dict and not the current state of the loaded module. + + :param module: module containing the parameter to update + :param key: name of parameter to update + :param data: tensor to update parameter with in the offloaded state dict + """ + if not is_module_offloaded(module): + raise ValueError("Prefix dict is only applicable to offloaded modules") + prefix_dict = module._hf_hook.weights_map + prefix_dict.dataset[f"{prefix_dict.prefix}{key}"] = data + + +# upstream candidate? +def update_offload_parameter( module: torch.nn.Module, - execution_device: Optional[torch.device] = None, - args = tuple(), kwargs = dict() + name: str, + data: torch.Tensor, + offload_device: Optional[torch.device] = None, ): + """ + :param module: module containing the parameter to update + :param name: name of module parameter to update + :param data: tensor to update parameter with + :param offload_device: offload device for newly registered parameters + """ + if data.device == "meta": + raise ValueError("Cannot copy data from meta device. Consider calling with align_module(module) context") + + param = getattr(module, name) + if param.data.dtype != data.dtype: + warnings.warn("TODO") + param.data.copy_(data) + + if has_offloaded_params(module): + weights_map = module._hf_hook.weights_map + + # for upstreaming, probably better to modify the weight map types so that they can be written to? + if isinstance(weights_map, PrefixedDataset): + prefix_dict = getattr_chain(module, "module._hf_hook.weights_map.dataset", None) + if prefix_dict is not None: + prefix = module._hf_hook.weights_map.prefix + key = f"{prefix}{name}" + + offload_device = ( + prefix_dict[key].device if key in prefix_dict + else offload_device if offload_device is not None + else _infer_offload_device(module) + ) + prefix_dict[key] = data.to(device=offload_device) + + if isinstance(weights_map, OffloadedWeightsLoader): + raise NotImplementedError() + + else: + raise NotImplementedError() + +# depreciation candidate +def update_parameter_data( + module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str +): + param = getattr(module, param_name) + new_param_data = new_param_data.to(device=param.device, dtype=param.dtype) + update_offload_parameter(module, param_name, new_param_data) + + +# upstream candidate +@contextlib.contextmanager +def align_module(module: torch.nn.Module, execution_device: Optional[torch.device] = None): """ Move a module's parameters to the execution device + :param module: module with parameters to align :param execution_device: if provided, overrides module execution device within the context @@ -218,7 +348,7 @@ def align_module( original_device = module._hf_hook.execution_device module._hf_hook.execution_device = original_device - module._hf_hook.pre_forward(module, *args, **kwargs) + module._hf_hook.pre_forward(module) yield module._hf_hook.post_forward(module, None) @@ -240,35 +370,33 @@ def align_module( yield -def update_offload_parameter( +@contextlib.contextmanager +def modify_offload_module( module: torch.nn.Module, - name: str, - data: torch.Tensor, - init_device: Optional[torch.device] = torch.device("cpu"), + execution_device: Optional[torch.device] = None, + offload_device: Optional[torch.device] = None, ): - """ - :param module: module containing the parameter to update - :param name: name of module parameter to update - :param data: tensor to update parameter with - :param init_device: offload device for newly registered parameters - """ - param = getattr(module, name) - param.data = data - - prefix_dict = getattr_chain(module, "module._hf_hook.weights_map.dataset", None) - if prefix_dict is not None: - prefix = module._hf_hook.weights_map.prefix - key = f"{prefix}{name}" + with align_module(module, execution_device): + yield - offload_device = prefix_dict[key].device if key in prefix_dict else init_device - prefix_dict[key] = data.to(device=offload_device) + # there is little performance gain from checking if a parameter's data + # has been modified before copying since the new data must be copied + # to the offload device anyways; just update all module parameters + for name, param in module.named_parameters(): + update_offload_parameter(module, name, param.data, offload_device) +# upstream candidate? def register_offload_parameter( module: torch.nn.Module, name: str, - data: torch.Tensor, - offload_device: Optional[torch.device] = torch.device("cpu"), + parameter: torch.nn.Parameter, + offload_device: Optional[torch.device] = None, ): - module.register_parameter(name, torch.nn.Parameter(data)) - update_offload_parameter(module, name, data, offload_device) \ No newline at end of file + module.register_parameter(name, parameter) + update_offload_parameter(module, name, parameter.data, offload_device) + + +# upstream candidate? +def deregister_offload_parameter(): + raise NotImplementedError() \ No newline at end of file From c4d2ddebb090262436cdd76f6e46399aafd73da2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 12:16:31 -0400 Subject: [PATCH 053/285] revert weird batching, support image text datasets --- .../modifiers/quantization/gptq/base.py | 2 +- src/llmcompressor/modifiers/utils/hooks.py | 86 ++----------------- .../transformers/finetune/data/base.py | 20 +++-- .../transformers/finetune/text_generation.py | 4 +- 4 files changed, 25 insertions(+), 87 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 767664640..68d603a4b 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -234,7 +234,7 @@ def calibration_forward( :param model: model to perform forward pass with :param dataloader: dataloader containing calibration dataset """ - #dataloader = create_single_batch_dataloader(dataloader.dataset) + dataloader = create_single_batch_dataloader(dataloader.dataset) with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 322f7b787..e3546b22f 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -166,6 +166,13 @@ def target_pre_forward( input ) + if self.true_sequential: + # compress first so output is from compressed weights + with CompressionLogger(module) as comp_logger: + loss = self.compress_module(name, module, args) + comp_logger.set_loss(loss) + + @HooksMixin.hook def target_post_forward( @@ -177,46 +184,11 @@ def target_post_forward( ): print(f"post {name}") - if module.gptq_hessian_samples >= 20: - # compress - print(f"compressing {name}") - if True: #self.true_sequential: - with CompressionLogger(module) as comp_logger: - loss = self.compress_module(name, module, args) - comp_logger.set_loss(loss) - - """ - breakpoint() - ret = torch.concat(self._module_outputs) - del self._module_inputs[module] - del self._module_outputs[module] - return ret - - # accumulate - self._module_outputs.append(output) - - if len(self._module_outputs) == 2: - with CompressionLogger(module) as comp_logger: - loss = self.compress_module(name, module, args) - comp_logger.set_loss(loss) - - ret = self._module_outputs - self._module_outputs = [] - - return ret - - if self.true_sequential: - # compress first so output is from compressed weights - with CompressionLogger(module) as comp_logger: - loss = self.compress_module(name, module, args) - comp_logger.set_loss(loss) - if not self.true_sequential: # compress after so output is from uncompressed weights with CompressionLogger(module) as comp_logger: loss = self.compress_module(name, module, args) comp_logger.set_loss(loss) - """ @HooksMixin.hook def layer_pre_forward(self, name: str, layer: torch.nn.Module, args: Any, kwargs): @@ -224,29 +196,6 @@ def layer_pre_forward(self, name: str, layer: torch.nn.Module, args: Any, kwargs f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====" ) - input = args[0] - - if not self.true_sequential: - self._module_inputs[layer] += [ - input[batch_index: batch_index + 1] - for batch_index in range(input.shape[0]) - ] - - - if len(self._module_outputs[layer]) >= 20 - 1: - # last sample can be passed normally - print("last layer forward") - return (input[-1:], *args[1:]), kwargs - - else: - forward_call = (layer._slow_forward if torch._C._get_tracing_state() else layer.forward) - for batch_index in range(input.size(0)): - print("layer forward") - output = forward_call(input[batch_index: batch_index + 1], *args[1:], **kwargs) - self._module_outputs[layer].append(output) - - raise EarlyStopException(torch.tensor([]), None) - @HooksMixin.hook def layer_post_forward( @@ -260,27 +209,6 @@ def layer_post_forward( print(f"post {name}") breakpoint() - # capture last sample - self._module_outputs[layer].append(output) - - # batch outputs - outputs = self._module_outputs[layer] - batched_outputs = tuple( - torch.concat(tuple( - outputs[sample_index][output_index] - for sample_index in range(len(outputs)) - )) - for output_index in range(len(outputs[0])) - ) - del self._module_outputs[layer] - - if not self.true_sequential: - pass # run again - - del self._module_inputs[layer] - - return batched_outputs - if not self.true_sequential: # rerun with (now) compressed weights with HooksMixin.disable_hooks(): diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index d4c3a6222..744313691 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -51,17 +51,17 @@ def __init__( self.padding = False if self.tokenizer: - if not self.tokenizer.pad_token: - self.tokenizer.pad_token = self.tokenizer.eos_token + if not self.tokenizer.tokenizer.pad_token: + self.tokenizer.tokenizer.pad_token = self.tokenizer.tokenizer.eos_token # configure sequence length max_seq_length = data_args.max_seq_length - model_max_length = tokenizer.model_max_length if tokenizer else max_seq_length + model_max_length = tokenizer.tokenizer.model_max_length if tokenizer else max_seq_length if self.tokenizer and max_seq_length > model_max_length: logger.warning( f"The max_seq_length passed ({max_seq_length}) is larger than " - f"the maximum length for the model ({tokenizer.model_max_length}). " - f"Using max_seq_length={tokenizer.model_max_length}." + f"the maximum length for the model ({tokenizer.tokenizer.model_max_length}). " + f"Using max_seq_length={tokenizer.tokenizer.model_max_length}." ) self.max_seq_length = min(data_args.max_seq_length, model_max_length) @@ -97,6 +97,7 @@ def get_raw_dataset(self, cache_dir: Optional[str] = None) -> Dataset: def tokenize_and_process( self, raw_dataset: Optional[Dataset] = None, add_labels: Optional[bool] = True ) -> Dataset: + breakpoint() """ Sets up the raw dataset for finetuning, performs tokenization, concatenates entries to max sequence length if desired, and adds labels to each entry @@ -107,6 +108,15 @@ def tokenize_and_process( # helper fn for tokenizing text column def tokenize_fn(data): + """ + inputs = processor( + image, + input_text, + add_special_tokens=False, + return_tensors="pt" + ).to(model.device) + """ + result = self.tokenizer( data[self.text_column], padding=self.padding, diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index 1856ca954..46829e8dc 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -23,7 +23,7 @@ from loguru import logger from transformers import ( AutoConfig, - AutoTokenizer, + AutoProcessor, DefaultDataCollator, HfArgumentParser, set_seed, @@ -221,7 +221,7 @@ def initialize_model_from_path( def initialize_tokenizer_from_path(model_args, model, teacher): tokenizer_src = model_args.tokenizer tokenizer_src = tokenizer_src or get_shared_tokenizer_src(model, teacher) - tokenizer = AutoTokenizer.from_pretrained( + tokenizer = AutoProcessor.from_pretrained( tokenizer_src, cache_dir=model_args.cache_dir, use_fast=True, From 670b35e10c9b174049724fe51cd2af583cb97ef5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 12:17:13 -0400 Subject: [PATCH 054/285] remove breakpoint --- src/llmcompressor/modifiers/utils/hooks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index e3546b22f..5bddb9c56 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -207,7 +207,6 @@ def layer_post_forward( output: Tuple[Any, ...], ): print(f"post {name}") - breakpoint() if not self.true_sequential: # rerun with (now) compressed weights From 3892b907b8cf85213c96be43523e658171a1f293 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 12:17:31 -0400 Subject: [PATCH 055/285] add example script --- shubhra.py | 95 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 shubhra.py diff --git a/shubhra.py b/shubhra.py new file mode 100644 index 000000000..229ffa78e --- /dev/null +++ b/shubhra.py @@ -0,0 +1,95 @@ +from datasets import load_dataset +from transformers import AutoProcessor, MllamaForConditionalGeneration + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import oneshot, wrap_hf_model_class +import os + +# Load model. +model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" +model_class = wrap_hf_model_class(MllamaForConditionalGeneration) +model = model_class.from_pretrained(model_id, device_map="auto", torch_dtype="auto", trust_remote_code=True, _attn_implementation="eager",) +processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + +print("Loading dataset") +DATASET_ID = "lmms-lab/flickr30k" +DATASET_SPLIT = "test[:128]" + +NUM_CALIBRATION_SAMPLES = 1#128 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) +ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) + +print("Preprocessing samples") +def preprocess(example): + messages = [ + [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What does the image show?"} + ] + } + ], + ] + return { + "text": processor.apply_chat_template( + messages, + add_generation_prompt=True, + ), + } + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + tmp = processor(sample["image"], sample["text"], add_special_tokens=False, return_tensors="pt") + for key in tmp: + tmp[key] = tmp[key].squeeze(0) + + return tmp + + +ds = ds.map(tokenize, remove_columns=ds.column_names) +print(ds) + +print("Setting up quantization params") +# Configure the quantization algorithm and scheme. +# In this case, we: +# * quantize the weights to fp8 with per channel via ptq +# * quantize the activations to fp8 with dynamic per token +#ignore=["re:.*lm_head", "re:model.vision_embed_tokens.*"] +#ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*", "re:language_model.*cross_attn.*"], +ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"] + +recipe = [ + # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), + GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore), +] + +save_name = model_id.split("/")[1] + "-W8A8" +save_path = os.path.join("./my_test/", save_name) +print("Starting quantization") +oneshot( + model=model, + tokenizer=model_id, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, + output_dir=save_path, +) + +#processor.save_pretrained(save_path) + +# Confirm generations of the quantized model look sane. +print("========== SAMPLE GENERATION ==============") +input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=20) +print(processor.decode(output[0])) +print("==========================================") From 03515f0ecfab91061e2b60fb338ed1b1a898533f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 17:52:12 +0000 Subject: [PATCH 056/285] remove hessian --- .../modifiers/quantization/gptq/base.py | 18 +--- src/llmcompressor/modifiers/utils/hooks.py | 85 ++----------------- 2 files changed, 12 insertions(+), 91 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 7e6e5556d..119179666 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -26,7 +26,7 @@ from llmcompressor.transformers.finetune.data.data_helpers import ( create_single_batch_dataloader, ) -from llmcompressor.utils.fsdp.helpers import has_offloaded_params, register_offload_parameter +from llmcompressor.utils.fsdp.helpers import has_offloaded_params from llmcompressor.utils.helpers import ( align_module, calibration_forward_context, @@ -259,10 +259,7 @@ def compress_module( inp = args[0] quant_args = getattr_chain(module, "quantization_scheme.weights") - offloaded = is_module_offloaded(module) - if offloaded: - module._hf_hook.pre_forward(module) - + with align_module(module): loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( module.weight.data, inp, @@ -273,21 +270,14 @@ def compress_module( original_weight=module.original_weight.data, ) - delattr(module, "gptq_hessian") - delattr(module, "gptq_hessian_samples") - - # FUTURE: Implement learning rate modification to weight update + #weight_update_acc = module.weight_update_acc.data + quantized_weight + #update_parameter_data(module, quantized_weight, "weight") - if is_module_offloaded(module): - update_prefix_dict(self.layer, "weight", quantized_weight) update_parameter_data(module, quantized_weight, "weight") update_parameter_data(module, scale, "weight_scale") update_parameter_data(module, zero_point, "weight_zero_point") update_parameter_data(module, g_idx, "weight_g_idx") - if offloaded: - module._hf_hook.post_forward(module, None) - return loss def _build_quant_modifier(self): diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 414c75bc1..3bda3da2c 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -137,8 +137,8 @@ def register_hooks(self, model: torch.nn.Module): self.register_hook(module.register_forward_hook(post_hook)) #register_offload_parameter(module, "weight_original", torch.nn.Parameter(module.weight.data.clone(), requires_grad=False)) # TODO: better name? - register_offload_parameter(module, "weight_update_acc", torch.nn.Parameter(module.weight.data.clone(), requires_grad=False)) # TODO: better name? - register_offload_parameter(module, "num_samples", torch.nn.Parameter(torch.tensor(0.0), requires_grad=False)) # TODO: better name? + #register_offload_parameter(module, "weight_update_acc", torch.nn.Parameter(torch.zeros_like(module.weight.data), requires_grad=False)) # TODO: better name? + #register_offload_parameter(module, "num_samples", torch.nn.Parameter(torch.tensor(0.0), requires_grad=False)) # TODO: better name? if name in layers.keys(): pre_hook = partial(self.layer_pre_forward, name) @@ -153,30 +153,12 @@ def register_hooks(self, model: torch.nn.Module): def target_pre_forward( self, name: str, module: torch.nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any] ): - input = args[0] - - # compute hessian - if not hasattr(module, "gptq_hessian"): - num_columns = module.weight.shape[1] - module.gptq_hessian = torch.zeros((num_columns, num_columns), dtype=torch.float32, device=input.device) - module.gptq_hessian_samples = 0 - - print(f"{name} adding {input.size(0)} samples") - module.gptq_hessian, module.gptq_hessian_samples = add_batch( - module.gptq_hessian, - module.gptq_hessian_samples, - module, - input - ) - - if self.true_sequential: - if module.gptq_hessian_samples >= 20: - # compress - print(f"compressing {name}") - if True: #self.true_sequential: - with CompressionLogger(module) as comp_logger: - loss = self.compress_module(name, module, args) - comp_logger.set_loss(loss) + # compress + print(f"compressing {name}") + if True: #self.true_sequential: + with CompressionLogger(module) as comp_logger: + loss = self.compress_module(name, module, args) + comp_logger.set_loss(loss) @HooksMixin.hook def target_post_forward( @@ -187,47 +169,6 @@ def target_post_forward( output: Tuple[Any, ...], ): print(f"post {name}") - if not self.true_sequential: - if module.gptq_hessian_samples >= 20: - # compress - print(f"compressing {name}") - if True: #self.true_sequential: - with CompressionLogger(module) as comp_logger: - loss = self.compress_module(name, module, args) - comp_logger.set_loss(loss) - - """ - breakpoint() - ret = torch.concat(self._module_outputs) - del self._module_inputs[module] - del self._module_outputs[module] - return ret - - # accumulate - self._module_outputs.append(output) - - if len(self._module_outputs) == 2: - with CompressionLogger(module) as comp_logger: - loss = self.compress_module(name, module, args) - comp_logger.set_loss(loss) - - ret = self._module_outputs - self._module_outputs = [] - - return ret - - if self.true_sequential: - # compress first so output is from compressed weights - with CompressionLogger(module) as comp_logger: - loss = self.compress_module(name, module, args) - comp_logger.set_loss(loss) - - if not self.true_sequential: - # compress after so output is from uncompressed weights - with CompressionLogger(module) as comp_logger: - loss = self.compress_module(name, module, args) - comp_logger.set_loss(loss) - """ @HooksMixin.hook def layer_pre_forward(self, name: str, layer: torch.nn.Module, args: Any, kwargs): @@ -246,15 +187,5 @@ def layer_post_forward( output: Tuple[Any, ...], ): print(f"post {name}") - - - if False and not self.true_sequential: # only print - # rerun with (now) compressed weights - with HooksMixin.disable_hooks(): - compressed_output = layer(*args, **kwargs) - - error = torch.nn.functional.l1_loss(output[0], compressed_output[0]) - logger.info(f"Mean output error from quantization: {error:.3f}") - self._layer_index += 1 return output From 6e37f649e63c81e402389c5a009b74bf70be6eb9 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 17:54:00 +0000 Subject: [PATCH 057/285] allocated original weight --- src/llmcompressor/modifiers/quantization/gptq/base.py | 2 +- .../modifiers/quantization/gptq/utils/gptq_quantize.py | 4 ++-- src/llmcompressor/modifiers/utils/hooks.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 119179666..6026fe530 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -267,7 +267,7 @@ def compress_module( blocksize=self.block_size, percdamp=self.dampening_frac, module_class=type(module), - original_weight=module.original_weight.data, + weight_original=module.weight_original.data, ) #weight_update_acc = module.weight_update_acc.data + quantized_weight diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 5f4f0cd22..617572391 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -109,7 +109,7 @@ def quantize_weight( blocksize: int = 128, percdamp: float = 0.01, module_class: Type[torch.nn.Module] = torch.nn.Linear, - original_weight: Optional[torch.Tensor] = None, + weight_original: Optional[torch.Tensor] = None, ) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]: """ Quantize a module weight according to the GPTQ algorithm @@ -202,7 +202,7 @@ def quantize_weight( W1_nz_mask = W_nz_mask[:, i1:i2] for i in range(count): - w = original_weight[:, i] + w = weight_original[:, i] d = Hinv1[i, i] q = W1[:, i].clone() diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 3bda3da2c..f3d674538 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -136,7 +136,7 @@ def register_hooks(self, model: torch.nn.Module): self.register_hook(module.register_forward_pre_hook(pre_hook, with_kwargs=True)) self.register_hook(module.register_forward_hook(post_hook)) - #register_offload_parameter(module, "weight_original", torch.nn.Parameter(module.weight.data.clone(), requires_grad=False)) # TODO: better name? + register_offload_parameter(module, "weight_original", module.weight.clone()) # TODO: better name? #register_offload_parameter(module, "weight_update_acc", torch.nn.Parameter(torch.zeros_like(module.weight.data), requires_grad=False)) # TODO: better name? #register_offload_parameter(module, "num_samples", torch.nn.Parameter(torch.tensor(0.0), requires_grad=False)) # TODO: better name? From 09dae14c7e661fecdef2f8b90cd01e1266733c03 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 17:56:18 +0000 Subject: [PATCH 058/285] proper clone --- src/llmcompressor/modifiers/utils/hooks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index f3d674538..7c8858cbc 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -136,7 +136,8 @@ def register_hooks(self, model: torch.nn.Module): self.register_hook(module.register_forward_pre_hook(pre_hook, with_kwargs=True)) self.register_hook(module.register_forward_hook(post_hook)) - register_offload_parameter(module, "weight_original", module.weight.clone()) # TODO: better name? + breakpoint() + register_offload_parameter(module, "weight_original", torch.nn.Parameter(module.weight.data.clone(), requires_grad=False)) # TODO: better name? #register_offload_parameter(module, "weight_update_acc", torch.nn.Parameter(torch.zeros_like(module.weight.data), requires_grad=False)) # TODO: better name? #register_offload_parameter(module, "num_samples", torch.nn.Parameter(torch.tensor(0.0), requires_grad=False)) # TODO: better name? From 944601e06a2c8279e1bc4e80f4157c4a1405d225 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 17:56:43 +0000 Subject: [PATCH 059/285] remove breakpoint --- src/llmcompressor/modifiers/utils/hooks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 7c8858cbc..05f7dfa3c 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -136,7 +136,6 @@ def register_hooks(self, model: torch.nn.Module): self.register_hook(module.register_forward_pre_hook(pre_hook, with_kwargs=True)) self.register_hook(module.register_forward_hook(post_hook)) - breakpoint() register_offload_parameter(module, "weight_original", torch.nn.Parameter(module.weight.data.clone(), requires_grad=False)) # TODO: better name? #register_offload_parameter(module, "weight_update_acc", torch.nn.Parameter(torch.zeros_like(module.weight.data), requires_grad=False)) # TODO: better name? #register_offload_parameter(module, "num_samples", torch.nn.Parameter(torch.tensor(0.0), requires_grad=False)) # TODO: better name? From adbcee8ccc407d4d03c20a97b5fea467831611a1 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 20:43:49 +0000 Subject: [PATCH 060/285] naive_update option --- .../modifiers/quantization/gptq/base.py | 72 +++++++++++++++---- src/llmcompressor/modifiers/utils/hooks.py | 13 +++- .../finetune/data/data_helpers.py | 8 ++- src/llmcompressor/utils/fsdp/helpers.py | 11 +-- 4 files changed, 79 insertions(+), 25 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 6026fe530..c26672422 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch +import math from compressed_tensors.quantization import ( QuantizationScheme, freeze_module_quantization, @@ -24,14 +25,18 @@ from llmcompressor.modifiers.utils.hooks import LayerCompressorMixin from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.transformers.finetune.data.data_helpers import ( - create_single_batch_dataloader, + create_batch_dataloader, ) -from llmcompressor.utils.fsdp.helpers import has_offloaded_params +from llmcompressor.utils.fsdp.helpers import has_offloaded_params, register_offload_parameter, update_offload_parameter from llmcompressor.utils.helpers import ( align_module, calibration_forward_context, getattr_chain, ) +from compressed_tensors.quantization import ( + fake_quantize, +) + from llmcompressor.utils.pytorch.module import qat_active __all__ = ["GPTQModifier"] @@ -75,10 +80,7 @@ class GPTQModifier(Modifier, LayerCompressorMixin): :param sequential_update: Whether or not to update weights sequentially by layer. This option is depreciated and setting to False is no longer supported - :param true_sequential: Used to control the granularity of compression updates - through the forward pass. Set to True to use the weight-compressed outputs - of each module, set to False to use the weight-compressed outputs of each - layer (transformer block), defaults to False + :param naive_update: TODO :param sequential_targets: list of layer names to compress during GPTQ, or '__ALL__' to compress every layer in the model :param block_size: Used to determine number of columns to compress in one pass @@ -109,7 +111,8 @@ class GPTQModifier(Modifier, LayerCompressorMixin): """ sequential_update: bool = True # DEPRECIATED - true_sequential: bool = False + naive_update: bool = False + batch_size: int = 1 sequential_targets: Union[str, List[str], None] = None block_size: int = 128 dampening_frac: Optional[float] = 0.01 @@ -124,6 +127,7 @@ class GPTQModifier(Modifier, LayerCompressorMixin): disable_quantization_observer_epoch: Optional[float] = None _quantization_modifier: Optional[QuantizationModifier] = PrivateAttr() + _num_batches: int = PrivateAttr() @field_validator("sequential_update", mode="before") def validate_sequential_update(cls, value: bool) -> bool: @@ -201,6 +205,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self._quantization_modifier.initialize(state, **kwargs) if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") + + self._num_batches = math.ceil(len(state.data.calib.dataset) / self.batch_size) self.register_hooks(state.model) self.calibration_forward(state.model, state.data.calib) @@ -222,6 +228,31 @@ def on_finalize(self, state: "State", **kwargs) -> bool: if self._quantization_modifier: self._quantization_modifier.finalize(state, **kwargs) + for module in state.model.modules(): + with align_module(module): + quant_args = getattr_chain(module, "quantization_scheme.weights", None) + if quant_args is None: + continue + + if self.naive_update: + weight = module.weight_acc / self._num_batches + delattr(module, "weight_acc") + + if self.naive_update: + weight = module.weight + delattr(module, "weight_original") + + scale, zero_point = quant_args.get_observer()(weight) + weight = fake_quantize( + weight, + scale, + zero_point, + quant_args, + ) + update_offload_parameter(module, "weight", weight) + update_offload_parameter(module, "scale", scale) + update_offload_parameter(module, "zero_point", zero_point) + return True def calibration_forward( @@ -234,10 +265,18 @@ def calibration_forward( :param model: model to perform forward pass with :param dataloader: dataloader containing calibration dataset """ - #dataloader = create_single_batch_dataloader(dataloader.dataset) + dataloader = create_batch_dataloader(dataloader.dataset, batch_size=self.batch_size) with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) + def pre_compress_module(self, module: torch.nn.Module): + # TODO: better names? + if self.naive_update: + register_offload_parameter(module, "weight_acc", torch.nn.Parameter(torch.zeros_like(module.weight.data), requires_grad=False)) + + else: + register_offload_parameter(module, "weight_original", torch.nn.Parameter(module.weight.data.clone(), requires_grad=False)) + def compress_module( self, name: str, @@ -267,16 +306,19 @@ def compress_module( blocksize=self.block_size, percdamp=self.dampening_frac, module_class=type(module), - weight_original=module.weight_original.data, + weight_original=module.weight_original.data if self.naive_update else module.weight.data ) - #weight_update_acc = module.weight_update_acc.data + quantized_weight - #update_parameter_data(module, quantized_weight, "weight") + if self.naive_update: + module.weight_acc += quantized_weight + update_offload_parameter(module, "weight_acc") + else: + module.weight += (quantized_weight - module.weight) * self._num_batches + update_offload_parameter(module, "weight") - update_parameter_data(module, quantized_weight, "weight") - update_parameter_data(module, scale, "weight_scale") - update_parameter_data(module, zero_point, "weight_zero_point") - update_parameter_data(module, g_idx, "weight_g_idx") + scale, zero_point = quant_args.get_observer()(module.weight) + update_offload_parameter(module, "scale", scale) + update_offload_parameter(module, "zero_point", zero_point) return loss diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 05f7dfa3c..bf03f6e86 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -109,6 +109,15 @@ class LayerCompressorMixin(HooksMixin): _layer_inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [] _layer_outputs: List[Tuple[Any, ...]] = [] + @abstractmethod + def pre_compress_module( + self, + name: str, + module: torch.nn.Module, + args: Tuple[torch.Tensor, ...], + ) -> float: + raise NotImplementedError() + @abstractmethod def compress_module( self, @@ -136,9 +145,7 @@ def register_hooks(self, model: torch.nn.Module): self.register_hook(module.register_forward_pre_hook(pre_hook, with_kwargs=True)) self.register_hook(module.register_forward_hook(post_hook)) - register_offload_parameter(module, "weight_original", torch.nn.Parameter(module.weight.data.clone(), requires_grad=False)) # TODO: better name? - #register_offload_parameter(module, "weight_update_acc", torch.nn.Parameter(torch.zeros_like(module.weight.data), requires_grad=False)) # TODO: better name? - #register_offload_parameter(module, "num_samples", torch.nn.Parameter(torch.tensor(0.0), requires_grad=False)) # TODO: better name? + self.pre_compress_module(module) if name in layers.keys(): pre_hook = partial(self.layer_pre_forward, name) diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index cc1c946ac..6f336aa8b 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -13,7 +13,7 @@ LABELS_MASK_VALUE = -100 __all__ = [ - "create_single_batch_dataloader", + "create_batch_dataloader", "format_calibration_data", "get_raw_dataset", "make_dataset_splits", @@ -22,13 +22,15 @@ ] -def create_single_batch_dataloader( +def create_batch_dataloader( dataset: datasets.Dataset, + batch_size: int, ) -> torch.utils.data.DataLoader: """ Create a dataloader whose batch size is equal to the size of the dataset :param dataset: dataset used to generate dataloader + :param batch_size: batch size of new dataloader :return: dataloader """ @@ -49,7 +51,7 @@ def pad_sequences(batch): return torch.utils.data.DataLoader( dataset, - batch_size=len(dataset), + batch_size=batch_size, shuffle=True, collate_fn=pad_sequences, pin_memory=True, diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py index e5ecda8c7..2707cbae8 100644 --- a/src/llmcompressor/utils/fsdp/helpers.py +++ b/src/llmcompressor/utils/fsdp/helpers.py @@ -284,7 +284,7 @@ def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor): def update_offload_parameter( module: torch.nn.Module, name: str, - data: torch.Tensor, + data: Optional[torch.Tensor] = None, offload_device: Optional[torch.device] = None, ): """ @@ -297,9 +297,12 @@ def update_offload_parameter( raise ValueError("Cannot copy data from meta device. Consider calling with align_module(module) context") param = getattr(module, name) - if param.data.dtype != data.dtype: - warnings.warn("TODO") - param.data.copy_(data) + if data is None: + data = param.data + else: + if param.data.dtype != data.dtype: + warnings.warn("TODO") + param.data.copy_(data) if has_offloaded_params(module): weights_map = module._hf_hook.weights_map From f4acab20dcb032662f41ec1acbdce4632e94cc99 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 20:45:29 +0000 Subject: [PATCH 061/285] remove true sequential --- src/llmcompressor/modifiers/utils/hooks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index bf03f6e86..dce60a6a5 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -98,7 +98,6 @@ class LayerCompressorMixin(HooksMixin): :ivar compresss_module: Function to be called on target modules """ - true_sequential: bool sequential_targets: bool _layer_index = 0 From 151f566730645f820b9fcd9dac6c006ef4ed7595 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 20:48:09 +0000 Subject: [PATCH 062/285] allow update_offload_parameter to not require data --- src/llmcompressor/utils/fsdp/helpers.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py index 2707cbae8..5b60b68e2 100644 --- a/src/llmcompressor/utils/fsdp/helpers.py +++ b/src/llmcompressor/utils/fsdp/helpers.py @@ -293,15 +293,14 @@ def update_offload_parameter( :param data: tensor to update parameter with :param offload_device: offload device for newly registered parameters """ - if data.device == "meta": - raise ValueError("Cannot copy data from meta device. Consider calling with align_module(module) context") - param = getattr(module, name) - if data is None: - data = param.data - else: + if data is not None: + if data.device == "meta": + raise ValueError("Cannot copy data from meta device. Consider calling with align_module(module) context") + if param.data.dtype != data.dtype: warnings.warn("TODO") + param.data.copy_(data) if has_offloaded_params(module): @@ -319,7 +318,7 @@ def update_offload_parameter( else offload_device if offload_device is not None else _infer_offload_device(module) ) - prefix_dict[key] = data.to(device=offload_device) + prefix_dict[key] = param.data.to(device=offload_device) if isinstance(weights_map, OffloadedWeightsLoader): raise NotImplementedError() From 76ebc8609c05ced749784f3b12e6df2cbc05eb45 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 20:49:25 +0000 Subject: [PATCH 063/285] bugfix --- src/llmcompressor/modifiers/quantization/gptq/base.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index c26672422..4a22c6f9b 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -111,7 +111,7 @@ class GPTQModifier(Modifier, LayerCompressorMixin): """ sequential_update: bool = True # DEPRECIATED - naive_update: bool = False + naive_update: bool = True batch_size: int = 1 sequential_targets: Union[str, List[str], None] = None block_size: int = 128 @@ -250,8 +250,8 @@ def on_finalize(self, state: "State", **kwargs) -> bool: quant_args, ) update_offload_parameter(module, "weight", weight) - update_offload_parameter(module, "scale", scale) - update_offload_parameter(module, "zero_point", zero_point) + update_offload_parameter(module, "weight_scale", scale) + update_offload_parameter(module, "weight_zero_point", zero_point) return True @@ -317,8 +317,8 @@ def compress_module( update_offload_parameter(module, "weight") scale, zero_point = quant_args.get_observer()(module.weight) - update_offload_parameter(module, "scale", scale) - update_offload_parameter(module, "zero_point", zero_point) + update_offload_parameter(module, "weight_scale", scale) + update_offload_parameter(module, "weight_zero_point", zero_point) return loss From 3480d6b75c0df8c2de4d8c687a18d1cc8ee262d3 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 20:50:59 +0000 Subject: [PATCH 064/285] ba --- src/llmcompressor/modifiers/quantization/gptq/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 4a22c6f9b..c90574b2a 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -238,7 +238,7 @@ def on_finalize(self, state: "State", **kwargs) -> bool: weight = module.weight_acc / self._num_batches delattr(module, "weight_acc") - if self.naive_update: + else: weight = module.weight delattr(module, "weight_original") @@ -306,7 +306,7 @@ def compress_module( blocksize=self.block_size, percdamp=self.dampening_frac, module_class=type(module), - weight_original=module.weight_original.data if self.naive_update else module.weight.data + weight_original=module.weight.data if self.naive_update else module.weight_original.data ) if self.naive_update: From 7c55fc596d14eba258b366fbf7032c7ba5a26bfd Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 29 Oct 2024 19:12:30 -0400 Subject: [PATCH 065/285] delete parameter --- .../modifiers/quantization/gptq/base.py | 6 +-- src/llmcompressor/utils/fsdp/helpers.py | 51 +++++++++++++++---- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index c90574b2a..92788dce1 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -27,7 +27,7 @@ from llmcompressor.transformers.finetune.data.data_helpers import ( create_batch_dataloader, ) -from llmcompressor.utils.fsdp.helpers import has_offloaded_params, register_offload_parameter, update_offload_parameter +from llmcompressor.utils.fsdp.helpers import delete_offload_parameter, register_offload_parameter, update_offload_parameter from llmcompressor.utils.helpers import ( align_module, calibration_forward_context, @@ -236,11 +236,11 @@ def on_finalize(self, state: "State", **kwargs) -> bool: if self.naive_update: weight = module.weight_acc / self._num_batches - delattr(module, "weight_acc") + delete_offload_parameter(module, "weight_acc") else: weight = module.weight - delattr(module, "weight_original") + delete_offload_parameter(module, "weight_original") scale, zero_point = quant_args.get_observer()(weight) weight = fake_quantize( diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py index 5b60b68e2..e58b4f1c3 100644 --- a/src/llmcompressor/utils/fsdp/helpers.py +++ b/src/llmcompressor/utils/fsdp/helpers.py @@ -27,7 +27,7 @@ try: from accelerate.hooks import AlignDevicesHook - from accelerate.utils import OffloadedWeightsLoader, PrefixedDataset + from accelerate.utils import OffloadedWeightsLoader, PrefixedDataset, set_module_tensor_to_device _has_accelerate = True except ImportError: _has_accelerate = False @@ -339,16 +339,20 @@ def update_parameter_data( @contextlib.contextmanager def align_module(module: torch.nn.Module, execution_device: Optional[torch.device] = None): """ - Move a module's parameters to the execution device + Moves a module's parameters to the specified execution device. - :param module: module with parameters to align - :param execution_device: if provided, overrides module execution device - within the context + Args: + module (torch.nn.Module): Module with parameters to align. + execution_device (Optional[torch.device]): If provided, overrides the + module's execution device within the context. + + Yields: + None: Yields control while the module's parameters are aligned to the execution device. """ if has_offloaded_params(module): if execution_device is not None: original_device = module._hf_hook.execution_device - module._hf_hook.execution_device = original_device + module._hf_hook.execution_device = execution_device module._hf_hook.pre_forward(module) yield @@ -361,17 +365,26 @@ def align_module(module: torch.nn.Module, execution_device: Optional[torch.devic devices = {} for name, param in module.named_parameters(): devices[name] = param.device - setattr(module, name, param.to(execution_device)) + set_module_tensor_to_device( + module, + name, + execution_device, + ) yield - for name, param_device in module.named_parameters: - setattr(module, name, param.to(param_device)) + for name, param in module.named_parameters(): + set_module_tensor_to_device( + module, + name, + devices[name], + ) else: yield + @contextlib.contextmanager def modify_offload_module( module: torch.nn.Module, @@ -400,5 +413,21 @@ def register_offload_parameter( # upstream candidate? -def deregister_offload_parameter(): - raise NotImplementedError() \ No newline at end of file +def delete_offload_parameter(module: torch.nn.Module, name: str): + delattr(module, name) + + if has_offloaded_params(module): + weights_map = module._hf_hook.weights_map + + # for upstreaming, probably better to modify the weight map types so that they can be written to? + if isinstance(weights_map, PrefixedDataset): + dataset = weights_map.dataset + prefix = weights_map.prefix + if dataset is not None: + del dataset[f"{prefix}{name}"] + + elif isinstance(weights_map, OffloadedWeightsLoader): + raise NotImplementedError() + + elif weights_map is not None: + raise NotImplementedError(f"Cannot delete parameter from weights_map of type {type(weights_map)}") \ No newline at end of file From 0a8004b00725451a43b40677ddb0fbbe3268e04c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 30 Oct 2024 15:01:01 -0400 Subject: [PATCH 066/285] sensible generations for small calibration size --- .../modifiers/quantization/gptq/base.py | 57 +++++++++++-------- .../quantization/gptq/utils/gptq_quantize.py | 33 +++-------- .../finetune/data/data_helpers.py | 6 +- 3 files changed, 43 insertions(+), 53 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 92788dce1..c1e535215 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -7,20 +7,12 @@ QuantizationScheme, freeze_module_quantization, ) -from compressed_tensors.utils import ( - is_module_offloaded, - update_parameter_data, - update_prefix_dict, -) from loguru import logger from pydantic import Field, PrivateAttr, field_validator from llmcompressor.core import State from llmcompressor.modifiers import Modifier, ModifierFactory -from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import ( - add_batch, - quantize_weight, -) +from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import quantize_weight from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier from llmcompressor.modifiers.utils.hooks import LayerCompressorMixin from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward @@ -112,7 +104,7 @@ class GPTQModifier(Modifier, LayerCompressorMixin): sequential_update: bool = True # DEPRECIATED naive_update: bool = True - batch_size: int = 1 + batch_size: int = -1 sequential_targets: Union[str, List[str], None] = None block_size: int = 128 dampening_frac: Optional[float] = 0.01 @@ -139,6 +131,16 @@ def validate_sequential_update(cls, value: bool) -> bool: ) return True + + @field_validator("naive_update", mode="before") + def validate_naive_update(cls, value: bool) -> bool: + if not value: + raise ValueError( + "`naive_update=False` is not implemented yet, please use " + "`naive_update=True`" + ) + + return True def on_initialize_structure(self, state: State, **kwargs): """ @@ -206,29 +208,23 @@ def on_initialize(self, state: "State", **kwargs) -> bool: if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") + if self.batch_size <= 0: + self.batch_size = len(state.data.calib.dataset) self._num_batches = math.ceil(len(state.data.calib.dataset) / self.batch_size) self.register_hooks(state.model) self.calibration_forward(state.model, state.data.calib) - #state.model(**state.model.dummy_inputs) self.remove_hooks() + self.finish_compression() # freeze quantization state.model.apply(freeze_module_quantization) return True - - def on_finalize(self, state: "State", **kwargs) -> bool: - """ - disable the quantization observers used by the OBCQ algorithm - - :param state: session state storing input model and calibration data - """ - if self._quantization_modifier: - self._quantization_modifier.finalize(state, **kwargs) - - for module in state.model.modules(): + + def finish_compression(self, model: torch.nn.Module): + for module in model.modules(): with align_module(module): quant_args = getattr_chain(module, "quantization_scheme.weights", None) if quant_args is None: @@ -253,6 +249,15 @@ def on_finalize(self, state: "State", **kwargs) -> bool: update_offload_parameter(module, "weight_scale", scale) update_offload_parameter(module, "weight_zero_point", zero_point) + def on_finalize(self, state: "State", **kwargs) -> bool: + """ + disable the quantization observers used by the OBCQ algorithm + + :param state: session state storing input model and calibration data + """ + if self._quantization_modifier: + self._quantization_modifier.finalize(state, **kwargs) + return True def calibration_forward( @@ -265,7 +270,7 @@ def calibration_forward( :param model: model to perform forward pass with :param dataloader: dataloader containing calibration dataset """ - dataloader = create_batch_dataloader(dataloader.dataset, batch_size=self.batch_size) + dataloader = create_batch_dataloader(dataloader, batch_size=self.batch_size) with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) @@ -297,6 +302,7 @@ def compress_module( # Assume that first argument is the input inp = args[0] quant_args = getattr_chain(module, "quantization_scheme.weights") + logger.info(f"Using {inp.size(0)} samples") with align_module(module): loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( @@ -306,19 +312,20 @@ def compress_module( blocksize=self.block_size, percdamp=self.dampening_frac, module_class=type(module), - weight_original=module.weight.data if self.naive_update else module.weight_original.data + weight_original=None if self.naive_update else module.weight_original.data ) if self.naive_update: module.weight_acc += quantized_weight update_offload_parameter(module, "weight_acc") else: - module.weight += (quantized_weight - module.weight) * self._num_batches + module.weight += (quantized_weight - module.weight) / self._num_batches update_offload_parameter(module, "weight") scale, zero_point = quant_args.get_observer()(module.weight) update_offload_parameter(module, "weight_scale", scale) update_offload_parameter(module, "weight_zero_point", zero_point) + update_offload_parameter(module, "weight_g_idx", g_idx) # NOT SURE IF THIS IS CORRECT return loss diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 617572391..a2354a242 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -17,29 +17,7 @@ GPTQ_PRECISION = torch.float32 -def add_batch(H: torch.Tensor, nsamples: int , module: torch.nn.Module, inp: torch.Tensor): - """ - Add a batch of layer input and output data to the Hessian calculation - """ - if len(inp.shape) == 2: - inp = inp.unsqueeze(0) - tmp = inp.shape[0] - if isinstance(module, torch.nn.Linear) or isinstance( - module, transformers.Conv1D - ): - if len(inp.shape) == 3: - inp = inp.reshape((-1, inp.shape[-1])) - inp = inp.t() - H *= nsamples / (nsamples + tmp) - nsamples += tmp - inp = inp.to(dtype=H.dtype) - inp = math.sqrt(2 / nsamples) * inp - H += inp.matmul(inp.t()) - - return H, nsamples - - -def compute_hessian(inp: torch.Tensor, module_class, device) -> torch.Tensor: +def compute_hessian(inp: torch.Tensor, module_class: Type[torch.nn.Module], device) -> torch.Tensor: """ Calculate the hessian with respect to the module inputs @@ -129,7 +107,8 @@ def quantize_weight( final_dtype = weight.dtype W = weight.data.clone() - H = compute_hessian(inp, module_class, device=weight.device) + if weight_original is not None: + raise NotImplementedError() # standardize shape and dtype if module_class == torch.nn.Conv2d: @@ -140,6 +119,8 @@ def quantize_weight( num_rows = W.shape[0] num_columns = W.shape[1] + H = compute_hessian(inp, module_class, device=weight.device) + if strategy == QuantizationStrategy.GROUP: # mapping from column index to group index g_idx = ( @@ -202,9 +183,9 @@ def quantize_weight( W1_nz_mask = W_nz_mask[:, i1:i2] for i in range(count): - w = weight_original[:, i] + w = W1[:, i] d = Hinv1[i, i] - q = W1[:, i].clone() + q = w.clone() # quantize column if strategy == QuantizationStrategy.TENSOR: diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index 6f336aa8b..92d73edc2 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -23,7 +23,7 @@ def create_batch_dataloader( - dataset: datasets.Dataset, + dataloader: torch.utils.data.DataLoader, batch_size: int, ) -> torch.utils.data.DataLoader: """ @@ -33,6 +33,8 @@ def create_batch_dataloader( :param batch_size: batch size of new dataloader :return: dataloader """ + dataset = dataloader.dataset + sampler = dataloader.sampler.__class__(dataset) def pad_sequences(batch): # extract input_ids and attention_mask from the batch @@ -52,7 +54,7 @@ def pad_sequences(batch): return torch.utils.data.DataLoader( dataset, batch_size=batch_size, - shuffle=True, + sampler=sampler, collate_fn=pad_sequences, pin_memory=True, ) From d234b322df80dc565253384473a18b1c000c7f14 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 30 Oct 2024 15:13:04 -0400 Subject: [PATCH 067/285] remove unnecessary variables --- src/llmcompressor/modifiers/utils/hooks.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index dce60a6a5..44e2cd8a7 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -102,11 +102,6 @@ class LayerCompressorMixin(HooksMixin): _layer_index = 0 _num_layers = 0 - _module_inputs: Dict[torch.nn.Module, List[Tuple[Tuple[Any, ...], Dict[str, Any]]]] = defaultdict(lambda: []) - _module_outputs: Dict[torch.nn.Module, Union[List[Tuple[Any, ...]], torch.Tensor]] = defaultdict(lambda: []) - - _layer_inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [] - _layer_outputs: List[Tuple[Any, ...]] = [] @abstractmethod def pre_compress_module( From eeb5c8316400540da4a0010f4ca0658bb0d02c62 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 30 Oct 2024 15:25:05 -0400 Subject: [PATCH 068/285] remove non-naive updating stuff to focus on naive updating --- .../modifiers/quantization/gptq/base.py | 46 +++---------------- src/llmcompressor/modifiers/utils/hooks.py | 3 -- 2 files changed, 7 insertions(+), 42 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index c1e535215..471340a26 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -72,7 +72,6 @@ class GPTQModifier(Modifier, LayerCompressorMixin): :param sequential_update: Whether or not to update weights sequentially by layer. This option is depreciated and setting to False is no longer supported - :param naive_update: TODO :param sequential_targets: list of layer names to compress during GPTQ, or '__ALL__' to compress every layer in the model :param block_size: Used to determine number of columns to compress in one pass @@ -103,7 +102,6 @@ class GPTQModifier(Modifier, LayerCompressorMixin): """ sequential_update: bool = True # DEPRECIATED - naive_update: bool = True batch_size: int = -1 sequential_targets: Union[str, List[str], None] = None block_size: int = 128 @@ -131,16 +129,6 @@ def validate_sequential_update(cls, value: bool) -> bool: ) return True - - @field_validator("naive_update", mode="before") - def validate_naive_update(cls, value: bool) -> bool: - if not value: - raise ValueError( - "`naive_update=False` is not implemented yet, please use " - "`naive_update=True`" - ) - - return True def on_initialize_structure(self, state: State, **kwargs): """ @@ -216,7 +204,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self.calibration_forward(state.model, state.data.calib) self.remove_hooks() - self.finish_compression() + self.finish_compression(state.model) # freeze quantization state.model.apply(freeze_module_quantization) @@ -230,13 +218,8 @@ def finish_compression(self, model: torch.nn.Module): if quant_args is None: continue - if self.naive_update: - weight = module.weight_acc / self._num_batches - delete_offload_parameter(module, "weight_acc") - - else: - weight = module.weight - delete_offload_parameter(module, "weight_original") + weight = module.weight_acc / self._num_batches + delete_offload_parameter(module, "weight_acc") scale, zero_point = quant_args.get_observer()(weight) weight = fake_quantize( @@ -275,12 +258,7 @@ def calibration_forward( run_calibration_forward(model, dataloader, mask_padding=True) def pre_compress_module(self, module: torch.nn.Module): - # TODO: better names? - if self.naive_update: - register_offload_parameter(module, "weight_acc", torch.nn.Parameter(torch.zeros_like(module.weight.data), requires_grad=False)) - - else: - register_offload_parameter(module, "weight_original", torch.nn.Parameter(module.weight.data.clone(), requires_grad=False)) + register_offload_parameter(module, "weight_acc", torch.nn.Parameter(torch.zeros_like(module.weight.data), requires_grad=False)) def compress_module( self, @@ -305,27 +283,17 @@ def compress_module( logger.info(f"Using {inp.size(0)} samples") with align_module(module): - loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( + loss, quantized_weight, _scale, _zero_point, _g_idx = quantize_weight( module.weight.data, inp, quant_args, blocksize=self.block_size, percdamp=self.dampening_frac, module_class=type(module), - weight_original=None if self.naive_update else module.weight_original.data ) - if self.naive_update: - module.weight_acc += quantized_weight - update_offload_parameter(module, "weight_acc") - else: - module.weight += (quantized_weight - module.weight) / self._num_batches - update_offload_parameter(module, "weight") - - scale, zero_point = quant_args.get_observer()(module.weight) - update_offload_parameter(module, "weight_scale", scale) - update_offload_parameter(module, "weight_zero_point", zero_point) - update_offload_parameter(module, "weight_g_idx", g_idx) # NOT SURE IF THIS IS CORRECT + module.weight_acc += quantized_weight + update_offload_parameter(module, "weight_acc") return loss diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 44e2cd8a7..9dcbac9d3 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -9,9 +9,6 @@ from torch.utils.hooks import RemovableHandle from collections import defaultdict -from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import add_batch -from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException -from llmcompressor.utils.fsdp.helpers import register_offload_parameter from llmcompressor.utils.helpers import getattr_chain from llmcompressor.utils.metric_logging import CompressionLogger from llmcompressor.utils.pytorch.module import get_layers, get_no_split_params From c7c8d04aad57a0f74672440e3245ddfc9264d790 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 1 Nov 2024 13:38:56 -0400 Subject: [PATCH 069/285] use observer to calculate qparams --- .../quantization/gptq/utils/gptq_quantize.py | 37 ++++++------------- 1 file changed, 12 insertions(+), 25 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index a2354a242..a625e8a7b 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -2,6 +2,7 @@ from copy import copy from typing import Tuple, Union, Optional, Type +from llmcompressor.observers.base import Observer import torch import transformers from compressed_tensors.quantization import ( @@ -59,27 +60,6 @@ def invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: return H -def compute_scale_zero_point( - W: torch.Tensor, - quant_args: QuantizationArgs, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Compute the scale and zero point of a module weight - TODO: revisit after observers refactor - - :param W: module weight - :param quant_args: quantization arguments which determine how quantization - parameters are calculated - :return: scale and zero_point - """ - # TODO: revisit after observers refactor - - scale, zero_point = quant_args.get_observer()(W, g_idx=None) - scale = scale.to(dtype=W.dtype) - zero_point = zero_point.to(dtype=quant_args.pytorch_dtype()) - return scale, zero_point - - def quantize_weight( weight: torch.Tensor, inp: torch.Tensor, @@ -107,6 +87,13 @@ def quantize_weight( final_dtype = weight.dtype W = weight.data.clone() + # create observer for calculating quantization parameters + observer = Observer.load_from_registry( + "minmax", + quantization_args=quant_args, + averaging_constant=1.0, # ignore moving average + ) + if weight_original is not None: raise NotImplementedError() @@ -131,22 +118,22 @@ def quantize_weight( if actorder == ActivationOrdering.GROUP: # permute by activation order first, then update groups W, H, perm = _apply_activation_ordering(W, H) - scale, zero_point = compute_scale_zero_point(W, quant_args) + scale, zero_point = observer(W, g_idx=None) # use identity g_idx (invert permutation later) elif actorder == ActivationOrdering.WEIGHT: # update groups first, then permute by activation order - scale, zero_point = compute_scale_zero_point(W, quant_args) + scale, zero_point = observer(W, g_idx=None) W, H, perm = _apply_activation_ordering(W, H) # permute g_idx to maintain identity mapping after unpermutation g_idx = g_idx[perm] else: - scale, zero_point = compute_scale_zero_point(W, quant_args) + scale, zero_point = observer(W, g_idx=None) else: - scale, zero_point = compute_scale_zero_point(W, quant_args) + scale, zero_point = observer(W, g_idx=None) # sparsity mask sparsity = tensor_sparsity(W) From 2beb59a2f436d651928d3d5365d97d4ab2f732fe Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 5 Nov 2024 16:34:17 +0000 Subject: [PATCH 070/285] remove tokenizer args --- src/llmcompressor/transformers/finetune/text_generation.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index 46829e8dc..e055727e7 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -223,10 +223,6 @@ def initialize_tokenizer_from_path(model_args, model, teacher): tokenizer_src = tokenizer_src or get_shared_tokenizer_src(model, teacher) tokenizer = AutoProcessor.from_pretrained( tokenizer_src, - cache_dir=model_args.cache_dir, - use_fast=True, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, trust_remote_code=model_args.trust_remote_code_model, ) From 4a336fe013cf796088563a9e674ee631527bfe90 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 5 Nov 2024 18:53:28 +0000 Subject: [PATCH 071/285] fix shapes --- shubhra.py | 20 +++++++++---------- .../finetune/data/data_helpers.py | 4 ++-- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/shubhra.py b/shubhra.py index 229ffa78e..f404ba4e0 100644 --- a/shubhra.py +++ b/shubhra.py @@ -1,13 +1,17 @@ from datasets import load_dataset -from transformers import AutoProcessor, MllamaForConditionalGeneration +from transformers import AutoProcessor, MllamaForConditionalGeneration, LlavaForConditionalGeneration from llmcompressor.modifiers.quantization import GPTQModifier from llmcompressor.transformers import oneshot, wrap_hf_model_class import os +from accelerate import init_empty_weights +#os.environ["CUDA_VISIBLE_DEVICES"] = "" # Load model. -model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" -model_class = wrap_hf_model_class(MllamaForConditionalGeneration) +#model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" +model_id = "mgoin/pixtral-12b" +#with init_empty_weights(): +model_class = wrap_hf_model_class(LlavaForConditionalGeneration) model = model_class.from_pretrained(model_id, device_map="auto", torch_dtype="auto", trust_remote_code=True, _attn_implementation="eager",) processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) @@ -15,7 +19,7 @@ DATASET_ID = "lmms-lab/flickr30k" DATASET_SPLIT = "test[:128]" -NUM_CALIBRATION_SAMPLES = 1#128 +NUM_CALIBRATION_SAMPLES = 2 MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. @@ -47,11 +51,7 @@ def preprocess(example): # Tokenize inputs. def tokenize(sample): - tmp = processor(sample["image"], sample["text"], add_special_tokens=False, return_tensors="pt") - for key in tmp: - tmp[key] = tmp[key].squeeze(0) - - return tmp + return processor(sample["image"], sample["text"], add_special_tokens=False, return_tensors="pt", max_length=MAX_SEQUENCE_LENGTH) ds = ds.map(tokenize, remove_columns=ds.column_names) @@ -68,7 +68,7 @@ def tokenize(sample): recipe = [ # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), - GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore), + GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore, sequential_targets=["MistralDecoderLayer"]), ] save_name = model_id.split("/")[1] + "-W8A8" diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index cc1c946ac..d9beaf60f 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -34,8 +34,8 @@ def create_single_batch_dataloader( def pad_sequences(batch): # extract input_ids and attention_mask from the batch - input_ids = [torch.tensor(item["input_ids"]) for item in batch] - masks = [torch.tensor(item["attention_mask"]) for item in batch] + input_ids = [torch.tensor(item["input_ids"]).squeeze(0) for item in batch] + masks = [torch.tensor(item["attention_mask"]).squeeze(0) for item in batch] # while 0 is not necessarily the "correct" padding value, the padded # input_ids are ignored according to the attention_mask From f1373478a31d567997a3d65c56bbb2e9273b7586 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 5 Nov 2024 16:20:56 -0500 Subject: [PATCH 072/285] complete, more or less --- .../modifiers/quantization/gptq/base.py | 54 ++++++++++++++----- src/llmcompressor/modifiers/utils/hooks.py | 6 +++ .../modifiers/utils/pytorch_helpers.py | 5 +- src/llmcompressor/utils/helpers.py | 2 +- 4 files changed, 52 insertions(+), 15 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index c7ef93948..91ec2ed2f 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -16,6 +16,7 @@ from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier from llmcompressor.modifiers.utils.hooks import LayerCompressorMixin from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward +from llmcompressor.observers.base import Observer from llmcompressor.transformers.finetune.data.data_helpers import ( create_batch_dataloader, ) @@ -197,11 +198,20 @@ def on_initialize(self, state: "State", **kwargs) -> bool: raise ValueError("To use the GPTQModifier, quantization must be enabled.") if self.batch_size <= 0: - self.batch_size = len(state.data.calib.dataset) - self._num_batches = math.ceil(len(state.data.calib.dataset) / self.batch_size) + batch_size = len(state.data.calib.dataset) + else: + batch_size = self.batch_size + self._num_batches = math.ceil(len(state.data.calib.dataset) / batch_size) self.register_hooks(state.model) - self.calibration_forward(state.model, state.data.calib) + #torch.cuda.memory._record_memory_history(max_entries=1_000_000) + try: + self.calibration_forward(state.model, state.data.calib) + finally: + pass + #torch.cuda.memory._dump_snapshot("bs256.pickle") + #torch.cuda.memory._record_memory_history(enabled=None) + #exit(0) self.remove_hooks() self.finish_compression(state.model) @@ -213,15 +223,22 @@ def on_initialize(self, state: "State", **kwargs) -> bool: def finish_compression(self, model: torch.nn.Module): for module in model.modules(): + quant_args = getattr_chain(module, "quantization_scheme.weights", None) + if quant_args is None: + continue + with align_module(module): - quant_args = getattr_chain(module, "quantization_scheme.weights", None) - if quant_args is None: - continue - weight = module.weight_acc / self._num_batches - delete_offload_parameter(module, "weight_acc") + if self.batch_size != -1: + weight = module.weight_acc / self._num_batches + delete_offload_parameter(module, "weight_acc") + else: + weight = module.weight - scale, zero_point = quant_args.get_observer()(weight) + observer = Observer.load_from_registry( + quant_args.observer, quantization_args=quant_args + ) + scale, zero_point = observer(weight) weight = fake_quantize( weight, scale, @@ -253,12 +270,18 @@ def calibration_forward( :param model: model to perform forward pass with :param dataloader: dataloader containing calibration dataset """ - dataloader = create_batch_dataloader(dataloader, batch_size=self.batch_size) + if self.batch_size <= 0: + batch_size = len(dataloader.dataset) + else: + batch_size = self.batch_size + dataloader = create_batch_dataloader(dataloader, batch_size=batch_size) with calibration_forward_context(model): run_calibration_forward(model, dataloader, mask_padding=True) def pre_compress_module(self, module: torch.nn.Module): - register_offload_parameter(module, "weight_acc", torch.nn.Parameter(torch.zeros_like(module.weight.data), requires_grad=False)) + if self.batch_size != -1: + print("created aux buffers") + register_offload_parameter(module, "weight_acc", torch.nn.Parameter(torch.zeros_like(module.weight.data), requires_grad=False)) def compress_module( self, @@ -292,8 +315,13 @@ def compress_module( module_class=type(module), ) - module.weight_acc += quantized_weight - update_offload_parameter(module, "weight_acc") + if self.batch_size != -1: + module.weight_acc += quantized_weight + update_offload_parameter(module, "weight_acc") + else: + module.weight -= module.weight + module.weight += quantized_weight + update_offload_parameter(module, "weight") return loss diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 9dcbac9d3..a3bd5b6a6 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -9,6 +9,7 @@ from torch.utils.hooks import RemovableHandle from collections import defaultdict +from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException from llmcompressor.utils.helpers import getattr_chain from llmcompressor.utils.metric_logging import CompressionLogger from llmcompressor.utils.pytorch.module import get_layers, get_no_split_params @@ -168,6 +169,7 @@ def target_post_forward( ): print(f"post {name}") + @HooksMixin.hook def layer_pre_forward(self, name: str, layer: torch.nn.Module, args: Any, kwargs): logger.info( @@ -186,4 +188,8 @@ def layer_post_forward( ): print(f"post {name}") self._layer_index += 1 + + if name == "model.layers.31": + raise EarlyStopException(None, None) + return output diff --git a/src/llmcompressor/modifiers/utils/pytorch_helpers.py b/src/llmcompressor/modifiers/utils/pytorch_helpers.py index 43b261d99..8c7b8b318 100644 --- a/src/llmcompressor/modifiers/utils/pytorch_helpers.py +++ b/src/llmcompressor/modifiers/utils/pytorch_helpers.py @@ -26,6 +26,8 @@ class EarlyStopException(Exception): """ def __init__(self, args: Tuple, kwargs: Dict): + if args is None: + return self.args = tensors_to_device(args, "cpu") self.kwargs = kwargs @@ -98,7 +100,8 @@ def run_calibration_forward( except EarlyStopException as e: # model was stopped early, save last calculated output and # move on to next calibration sample - intermediates.append((e.args, e.kwargs)) + #intermediates.append((e.args, e.kwargs)) + pass # TODO: not ideal, figure out where we aren't freeing memory instead # currently without this we run OOM on the 2nd forward pass diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index 211bd01eb..a45db89c1 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -1127,7 +1127,7 @@ def align_module(module: torch.nn.Module, device: Optional[torch.device] = None) elif device is not None: devices = {} - for name, param in module.named_parameters(): + for name, param in module.named_parameters(recurse=False): devices[name] = param.device setattr(module, name, param.to(device)) From 593d4fda35e96514107124aa0cf49ce8191aeb67 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 5 Nov 2024 16:42:22 -0500 Subject: [PATCH 073/285] support vision datasets --- examples/quantization_w4a16/llama3_example.py | 6 +- .../quantization_w4a16/vision2_example.py | 83 +++++++++++++++++ examples/quantization_w4a16/vision_example.py | 88 +++++++++++++++++++ .../modifiers/quantization/gptq/base.py | 3 +- src/llmcompressor/modifiers/utils/hooks.py | 24 ++--- .../transformers/finetune/data/base.py | 9 +- .../finetune/data/data_helpers.py | 4 +- .../transformers/finetune/text_generation.py | 12 +-- 8 files changed, 202 insertions(+), 27 deletions(-) create mode 100644 examples/quantization_w4a16/vision2_example.py create mode 100644 examples/quantization_w4a16/vision_example.py diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index 939991ab6..f1545c992 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -9,7 +9,7 @@ model = SparseAutoModelForCausalLM.from_pretrained( MODEL_ID, - device_map="auto", + device_map="cuda:0", torch_dtype="auto", ) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) @@ -20,7 +20,7 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 512 +NUM_CALIBRATION_SAMPLES = 160 #2048 MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. @@ -55,7 +55,7 @@ def tokenize(sample): # Configure the quantization algorithm to run. # * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]) +recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"], batch_size=-1, dampening_frac=0.5) # Apply algorithms. oneshot( diff --git a/examples/quantization_w4a16/vision2_example.py b/examples/quantization_w4a16/vision2_example.py new file mode 100644 index 000000000..1f57bb9f9 --- /dev/null +++ b/examples/quantization_w4a16/vision2_example.py @@ -0,0 +1,83 @@ +from datasets import load_dataset +from transformers import AutoProcessor, MllamaForConditionalGeneration + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import oneshot + +# Select model and load it. +MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" + +model = MllamaForConditionalGeneration.from_pretrained( + MODEL_ID, + device_map="cuda:0", + torch_dtype="auto", +) +processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True) + +# Select calibration dataset. +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" + +# Select number of samples. 512 samples is a good place to start. +# Increasing the number of samples can improve accuracy. +NUM_CALIBRATION_SAMPLES = 160 #2048 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) +ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) + + +def preprocess(example): + return { + "text": processor.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return processor( + None, + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# Configure the quantization algorithm to run. +# * quantize the weights to 4 bit with GPTQ with a group size 128 +recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"], batch_size=1, dampening_frac=0.5) + +# Apply algorithms. +oneshot( + model=model, + tokenizer=MODEL_ID, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, +) + +# Confirm generations of the quantized model look sane. +print("\n\n") +print("========== SAMPLE GENERATION ==============") +input_ids = processor("Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=100) +print(processor.decode(output[0])) +print("==========================================\n\n") + +# Save to disk compressed. +SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128" +model.save_pretrained(SAVE_DIR, save_compressed=True) +processor.save_pretrained(SAVE_DIR) diff --git a/examples/quantization_w4a16/vision_example.py b/examples/quantization_w4a16/vision_example.py new file mode 100644 index 000000000..f89ada21a --- /dev/null +++ b/examples/quantization_w4a16/vision_example.py @@ -0,0 +1,88 @@ +from datasets import load_dataset +from transformers import AutoProcessor + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot + +# Select model and load it. +MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" + +model = SparseAutoModelForCausalLM.from_pretrained( + MODEL_ID, + device_map="cuda:0", + torch_dtype="auto", +) +breakpoint() +processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True) + +# Select calibration dataset. +DATASET_ID = "lmms-lab/flickr30k" +DATASET_SPLIT = "test[:165]" + +# Select number of samples. 512 samples is a good place to start. +# Increasing the number of samples can improve accuracy. +NUM_CALIBRATION_SAMPLES = 165 #2048 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) +ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) + + +def preprocess(example): + messages = [ + [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What does the image show?"} + ] + } + ], + ] + return { + "text": processor.apply_chat_template( + messages, + add_generation_prompt=True, + ), + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return processor(sample["image"], sample["text"], add_special_tokens=False, return_tensors="pt", max_length=MAX_SEQUENCE_LENGTH) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# Configure the quantization algorithm to run. +# * quantize the weights to 4 bit with GPTQ with a group size 128 +recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"], batch_size=-1, dampening_frac=0.5) + +# Apply algorithms. +oneshot( + model=model, + tokenizer=MODEL_ID, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, +) + +# Confirm generations of the quantized model look sane. +print("\n\n") +print("========== SAMPLE GENERATION ==============") +input_ids = processor("Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=100) +print(processor.decode(output[0])) +print("==========================================\n\n") + +# Save to disk compressed. +SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128" +model.save_pretrained(SAVE_DIR, save_compressed=True) +processor.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 91ec2ed2f..74877bf93 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -209,7 +209,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self.calibration_forward(state.model, state.data.calib) finally: pass - #torch.cuda.memory._dump_snapshot("bs256.pickle") + #torch.cuda.memory._dump_snapshot("bs10.pickle") #torch.cuda.memory._record_memory_history(enabled=None) #exit(0) @@ -306,6 +306,7 @@ def compress_module( logger.info(f"Using {inp.size(0)} samples") with align_module(module): + print(inp.shape) loss, quantized_weight, _scale, _zero_point, _g_idx = quantize_weight( module.weight.data, inp, diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index a3bd5b6a6..f6b359d4d 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -124,11 +124,11 @@ def register_hooks(self, model: torch.nn.Module): # if no targets are provided, default to the modules that shouldn't be # split by FSDP. For Transformers models this is equivalent to the # decoder layers (ie LlamaDecoderLayer) - sequential_targets = self.sequential_targets - if sequential_targets is None: - sequential_targets = get_no_split_params(model) - layers = get_layers(sequential_targets, model) - self._num_layers = len(layers) + # sequential_targets = self.sequential_targets + # if sequential_targets is None: + # sequential_targets = get_no_split_params(model) + # layers = get_layers(sequential_targets, model) + # self._num_layers = len(layers) for name, module in model.named_modules(): if getattr_chain(module, "quantization_scheme.weights", None) is not None: @@ -139,13 +139,13 @@ def register_hooks(self, model: torch.nn.Module): self.pre_compress_module(module) - if name in layers.keys(): - pre_hook = partial(self.layer_pre_forward, name) - post_hook = partial(self.layer_post_forward, name) - self.register_hook(module.register_forward_pre_hook(pre_hook, with_kwargs=True)) - self.register_hook( - module.register_forward_hook(post_hook, with_kwargs=True) - ) + # if name in layers.keys(): + # pre_hook = partial(self.layer_pre_forward, name) + # post_hook = partial(self.layer_post_forward, name) + # self.register_hook(module.register_forward_pre_hook(pre_hook, with_kwargs=True)) + # self.register_hook( + # module.register_forward_hook(post_hook, with_kwargs=True) + # ) @HooksMixin.hook diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index d4c3a6222..941306180 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -51,17 +51,20 @@ def __init__( self.padding = False if self.tokenizer: + if hasattr(self.tokenizer, "tokenizer"): + self.tokenizer = self.tokenizer.tokenizer + if not self.tokenizer.pad_token: self.tokenizer.pad_token = self.tokenizer.eos_token # configure sequence length max_seq_length = data_args.max_seq_length - model_max_length = tokenizer.model_max_length if tokenizer else max_seq_length + model_max_length = self.tokenizer.model_max_length if self.tokenizer else max_seq_length if self.tokenizer and max_seq_length > model_max_length: logger.warning( f"The max_seq_length passed ({max_seq_length}) is larger than " - f"the maximum length for the model ({tokenizer.model_max_length}). " - f"Using max_seq_length={tokenizer.model_max_length}." + f"the maximum length for the model ({self.tokenizer.model_max_length}). " + f"Using max_seq_length={self.tokenizer.model_max_length}." ) self.max_seq_length = min(data_args.max_seq_length, model_max_length) diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index 92d73edc2..caa6b07d3 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -38,8 +38,8 @@ def create_batch_dataloader( def pad_sequences(batch): # extract input_ids and attention_mask from the batch - input_ids = [torch.tensor(item["input_ids"]) for item in batch] - masks = [torch.tensor(item["attention_mask"]) for item in batch] + input_ids = [torch.tensor(item["input_ids"]).squeeze(0) for item in batch] + masks = [torch.tensor(item["attention_mask"]).squeeze(0) for item in batch] # while 0 is not necessarily the "correct" padding value, the padded # input_ids are ignored according to the attention_mask diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index 1856ca954..8db9ceac0 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -23,7 +23,7 @@ from loguru import logger from transformers import ( AutoConfig, - AutoTokenizer, + AutoProcessor, DefaultDataCollator, HfArgumentParser, set_seed, @@ -221,12 +221,12 @@ def initialize_model_from_path( def initialize_tokenizer_from_path(model_args, model, teacher): tokenizer_src = model_args.tokenizer tokenizer_src = tokenizer_src or get_shared_tokenizer_src(model, teacher) - tokenizer = AutoTokenizer.from_pretrained( + tokenizer = AutoProcessor.from_pretrained( tokenizer_src, - cache_dir=model_args.cache_dir, - use_fast=True, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, + # cache_dir=model_args.cache_dir, + # use_fast=True, + # revision=model_args.model_revision, + # use_auth_token=True if model_args.use_auth_token else None, trust_remote_code=model_args.trust_remote_code_model, ) From 0bdf98a327ef97c0b281de1896e23ed5a5d7d275 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 8 Nov 2024 12:56:26 -0500 Subject: [PATCH 074/285] use pixtral --- shubhra.py | 92 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 shubhra.py diff --git a/shubhra.py b/shubhra.py new file mode 100644 index 000000000..4996c8277 --- /dev/null +++ b/shubhra.py @@ -0,0 +1,92 @@ +from datasets import load_dataset +from transformers import AutoProcessor, MllamaForConditionalGeneration, LlavaForConditionalGeneration + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import oneshot, wrap_hf_model_class +import os + +# Load model. +#model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" +model_id = "mgoin/pixtral-12b" +model_class = wrap_hf_model_class(LlavaForConditionalGeneration) +model = model_class.from_pretrained(model_id, device_map="auto", torch_dtype="auto", trust_remote_code=True, _attn_implementation="eager",) +processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + +print("Loading dataset") +DATASET_ID = "lmms-lab/flickr30k" +DATASET_SPLIT = "test[:128]" + +NUM_CALIBRATION_SAMPLES = 1#128 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) +ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) + +print("Preprocessing samples") +def preprocess(example): + messages = [ + [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What does the image show?"} + ] + } + ], + ] + return { + "text": processor.apply_chat_template( + messages, + add_generation_prompt=True, + ), + } + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return processor(sample["image"], sample["text"], add_special_tokens=False, return_tensors="pt") + + +ds = ds.map(tokenize, remove_columns=ds.column_names) +print(ds) + +print("Setting up quantization params") +# Configure the quantization algorithm and scheme. +# In this case, we: +# * quantize the weights to fp8 with per channel via ptq +# * quantize the activations to fp8 with dynamic per token +#ignore=["re:.*lm_head", "re:model.vision_embed_tokens.*"] +#ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*", "re:language_model.*cross_attn.*"], +ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"] + +recipe = [ + # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), + GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore), +] + +save_name = model_id.split("/")[1] + "-W8A8" +save_path = os.path.join("./my_test/", save_name) +print("Starting quantization") +oneshot( + model=model, + tokenizer=model_id, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, + output_dir=save_path, +) + +#processor.save_pretrained(save_path) + +# Confirm generations of the quantized model look sane. +print("========== SAMPLE GENERATION ==============") +input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=20) +print(processor.decode(output[0])) +print("==========================================") From 9f43b5d56ed54233efc791cde4f021b6cf9a0cd7 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 8 Nov 2024 12:57:11 -0500 Subject: [PATCH 075/285] better stopping --- src/llmcompressor/modifiers/utils/hooks.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index f6b359d4d..267de838c 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -139,6 +139,12 @@ def register_hooks(self, model: torch.nn.Module): self.pre_compress_module(module) + if "head" in name: + def hook(module: torch.nn.Module, args: Tuple[Any, ...]): + raise EarlyStopException(None, None) + + self.register_hook(module.register_forward_pre_hook(hook, with_kwargs=False)) + # if name in layers.keys(): # pre_hook = partial(self.layer_pre_forward, name) # post_hook = partial(self.layer_post_forward, name) @@ -151,7 +157,7 @@ def register_hooks(self, model: torch.nn.Module): @HooksMixin.hook def target_pre_forward( self, name: str, module: torch.nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any] - ): + ): # compress print(f"compressing {name}") if True: #self.true_sequential: @@ -189,7 +195,4 @@ def layer_post_forward( print(f"post {name}") self._layer_index += 1 - if name == "model.layers.31": - raise EarlyStopException(None, None) - return output From 3d224db4101a0ab8a7bf07223ffa8168db6a7a53 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 13 Nov 2024 21:34:42 +0000 Subject: [PATCH 076/285] implement partitioned model --- examples/quantization_w4a16/llama3_example.py | 7 +- graph_resuming.py | 307 +++++++++++++++++ .../modifiers/quantization/gptq/base.py | 6 +- .../gptq/utils/partitioned_model.py | 315 ++++++++++++++++++ 4 files changed, 632 insertions(+), 3 deletions(-) create mode 100644 graph_resuming.py create mode 100644 src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index f1545c992..96f80051e 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -1,3 +1,4 @@ +from accelerate import cpu_offload from datasets import load_dataset from transformers import AutoTokenizer @@ -5,13 +6,15 @@ from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot # Select model and load it. -MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" +#MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" +MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" model = SparseAutoModelForCausalLM.from_pretrained( MODEL_ID, device_map="cuda:0", torch_dtype="auto", ) +# cpu_offload(model) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) # Select calibration dataset. @@ -20,7 +23,7 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 160 #2048 +NUM_CALIBRATION_SAMPLES = 2 #2048 MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. diff --git a/graph_resuming.py b/graph_resuming.py new file mode 100644 index 000000000..194e6232c --- /dev/null +++ b/graph_resuming.py @@ -0,0 +1,307 @@ +from typing import Any, Callable, Dict, List, Set + +import torch +import inspect +from collections import deque +from transformers import AutoModel +from torch.fx import GraphModule, Graph, Node +from transformers.modeling_outputs import BaseModelOutputWithPast + + +class Model(torch.nn.Module): + def __init__(self, vocab_size=4096, d_model=128, n_heads=1, d_ff=256, dropout=0.1): + super(Model, self).__init__() + + self.d_model = d_model + self.n_heads = n_heads + self.head_dim = d_model // n_heads + assert d_model % n_heads == 0, "d_model must be divisible by n_heads" + + # Embedding layer + self.embedding = torch.nn.Embedding(vocab_size, d_model) + + # Linear transformations for queries, keys, and values + self.query_linear = torch.nn.Linear(d_model, d_model) + self.key_linear = torch.nn.Linear(d_model, d_model) + self.value_linear = torch.nn.Linear(d_model, d_model) + + # Output linear layer to combine heads + self.out_linear = torch.nn.Linear(d_model, d_model) + + # Position-wise feed-forward network + self.feed_forward = torch.nn.Sequential( + torch.nn.Linear(d_model, d_ff), + torch.nn.ReLU(), + torch.nn.Linear(d_ff, d_model) + ) + + # Layer normalization layers + self.norm1 = torch.nn.LayerNorm(d_model) + self.norm2 = torch.nn.LayerNorm(d_model) + + # Dropout layer + self.dropout = torch.nn.Dropout(dropout) + + def scaled_dot_product_attention(self, query, key, value): + # Calculate attention scores + scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5) + attn_weights = torch.functional.F.softmax(scores, dim=-1) + output = torch.matmul(attn_weights, value) + return output + + def forward(self, input_ids): + # Apply embedding layer + x = self.embedding(input_ids) # (batch_size, seq_length, d_model) + + batch_size, seq_length, _ = x.size() + + # Linear projections + Q = self.query_linear(x) # (batch_size, seq_length, d_model) + K = self.key_linear(x) # (batch_size, seq_length, d_model) + V = self.value_linear(x) # (batch_size, seq_length, d_model) + + # Split Q, K, V into multiple heads + Q = Q.view(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) # (batch_size, n_heads, seq_length, head_dim) + K = K.view(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) # (batch_size, n_heads, seq_length, head_dim) + V = V.view(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) # (batch_size, n_heads, seq_length, head_dim) + + # Scaled dot-product attention + attn_output = self.scaled_dot_product_attention(Q, K, V) # (batch_size, n_heads, seq_length, head_dim) + + # Concatenate heads + attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model) + + # Apply final linear transformation + attn_output = self.out_linear(attn_output) + + # Add & Norm + x = x + self.dropout(attn_output) + x = self.norm1(x) + + # Feed-forward block + ff_output = self.feed_forward(x) + x = x + self.dropout(ff_output) + x = self.norm2(x) + + return BaseModelOutputWithPast(last_hidden_state=x) + + +def get_target_nodes(graph: GraphModule, targets: List[str]): + target_nodes = [] + for node in graph.graph.nodes: + if ( + node.op == "call_module" and + type(graph.get_submodule(node.target)).__name__ in targets + ): + target_nodes.append(node) + + return target_nodes + + +def check_assumption(graph: Graph) -> bool: + for node in graph.nodes: + for user in node.users: + if node not in user.all_input_nodes: + return False + + for input_node in node.all_input_nodes: + if node not in input_node.users: + return False + + if ( + len(node.users) != len(set(node.users)) or + len(node.all_input_nodes) != len(set(node.all_input_nodes)) + ): + return False + + return True + + +def topological_partition(graph: GraphModule, target_nodes: Set[Node]) -> List[List[Node]]: + # use list representation to maintain topological sorting + assert check_assumption(graph.graph) + + partitions: List[List[Node]] = [[]] + remaining_indegrees = {node: len(node.all_input_nodes) for node in graph.graph.nodes} + partition_index = 0 # global counter, not necessary but ensures partitions are connected + + # start with graph input nodes + queue = deque(node for node in graph.graph.nodes if remaining_indegrees[node] == 0) + while len(queue) > 0: + node = queue.popleft() + + # guarantee targets are assigned to disjoint partitions + if node in target_nodes: + partition_index += 1 + partitions.append([]) + + # assign to partition + partitions[partition_index].append(node) + + # recurse on last indegree only in order to guarantee that + # the node is assigned to maximal partition + for user in node.users: + remaining_indegrees[user] -= 1 + if remaining_indegrees[user] == 0: + queue.append(user) + + assert set().union(*partitions) == set(graph.graph.nodes) + return partitions + + +def partition_graph(model: torch.nn.Module, partitions: List[List[Node]]): + subgraphs = [] + + # create subgraphs + for partition_nodes in partitions: + # create a new graph for the partition + subgraph = Graph(model) + node_map = {} + + # add placeholders for inputs not in this subgraph. use set to deduplicate + new_input_nodes = { + input_node + for node in partition_nodes + for input_node in node.all_input_nodes + if input_node not in partition_nodes + } + for input_node in new_input_nodes: + node_map[input_node] = subgraph.placeholder(input_node.name) + + # add the nodes to subgraph + for node in partition_nodes: + node_map[node] = subgraph.node_copy(node, lambda n: node_map[n]) + + # add an output node to collect all subgraph outputs into a dictionary + if len(subgraph.find_nodes(op="output")) <= 0: + output_dict = { + node.name: node_map[node] + for node in partition_nodes + if any(user not in partition_nodes for user in node.users.keys()) + } + subgraph.output(output_dict) + + # Save the subgraph for this partition + subgraph.lint() + input_names = [node.name for node in subgraph.nodes if node.op == "placeholder"] + subgraphs.append({ + "graph": subgraph, + "code": subgraph.python_code("self"), + "input_names": input_names, + "consumed_names": [], + }) + + print([n for n in subgraph.nodes]) + assert check_assumption(subgraph) + + # populate consumed_names according to when inputs are last used + # in order to vacate the `intermediates` cache and save memory + all_input_names = set().union(*(subgraph["input_names"] for subgraph in subgraphs)) + for input_name in all_input_names: + for subgraph in reversed(subgraphs): + if input_name in subgraph["input_names"]: + subgraph["consumed_names"].append(input_name) + break + else: + assert False + + return subgraphs + + +def gptq_compress(name: str, module: torch.nn.Module, inputs: List[torch.Tensor]): + print(f"gptq_compress {name} {module} {inputs.shape}") + pass + + +class HookedModel: + def __init__(self): + self.hook_targets = [] + self.hook_target_nodes = [] + self.graph = None + self.subgraphs = [] + self.model = None + + def register_hook(self, func: Callable, targets: List[str]): + self.hook_targets.append((func, targets)) + + def init_forward(self, model: torch.nn.Module): + self.model = model + + # 1. create graph + self.graph: GraphModule = symbolic_trace(model) + + # 2. identify target nodes + for func, targets in self.hook_targets: + self.hook_target_nodes.append((func, get_target_nodes(self.graph, targets))) + + all_target_nodes = set().union(*(target_nodes for _, target_nodes in self.hook_target_nodes)) + + # 3. cut into partitions along target nodes + partitions: List[List[Node]] = topological_partition(self.graph, all_target_nodes) + self.subgraphs: List[GraphModule] = partition_graph(model, partitions) + + def forward(self, *args, **kwargs): + model_modules = {name: module for name, module in self.model.named_modules()} + + # 4. perform compression + intermediates = kwargs.copy() + for subgraph_index, subgraph in enumerate(self.subgraphs): + code = subgraph["code"] + exec(code.src, code.globals) + forward_function = code.globals.get("forward") + + inputs = {input_name: intermediates[input_name] for input_name in subgraph["input_names"]} + + # detect and call hooks + for func, target_nodes in self.hook_target_nodes: + target_nodes = set(target_node for target_node in target_nodes) + subgraph_node_names = set(node.name for node in subgraph["graph"].nodes if node.op == "call_module") + + for target_node in target_nodes: + if target_node.name in subgraph_node_names: + assert len(target_node.all_input_nodes) == 1 + + module = model_modules[target_node.target] + input_value = inputs[target_node.all_input_nodes[0].name] + func(target_node.target, module, input_value) + + if subgraph_index < len(self.subgraphs) - 1: + intermediates.update(forward_function(self.model, **inputs)) + + for consumed_name in subgraph["consumed_names"]: + del intermediates[consumed_name] + else: + return forward_function(self.model, **inputs) + + +if __name__ == "__main__": + use_dummy_model = True + sequence_length = 2048 + + if use_dummy_model: + model = Model() + from torch.fx import symbolic_trace + else: + model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B-Instruct") + from transformers.utils.fx import symbolic_trace + + data_loader = [ + {"input_ids": torch.zeros(sequence_length, dtype=torch.int32).reshape(1, sequence_length)}, + #{"input_ids": torch.zeros(sequence_length, dtype=torch.int32).reshape(1, sequence_length)}, + #{"input_ids": torch.zeros(sequence_length, dtype=torch.int32).reshape(1, sequence_length)}, + ] + + # modifier inits + hooked_model = HookedModel() + hooked_model.register_hook(gptq_compress, ["Linear"]) + + # some time after modifier inits but before forward passes + hooked_model.init_forward(model) + + # oneshot/ eval loop + model.eval() + with torch.no_grad(): + for batch in data_loader: + hooked_output = hooked_model.forward(**batch) + model_output = model.forward(**batch) + assert torch.equal(hooked_output["last_hidden_state"], model_output["last_hidden_state"]) \ No newline at end of file diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 74877bf93..3618af76d 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -13,6 +13,7 @@ from llmcompressor.core import State from llmcompressor.modifiers import Modifier, ModifierFactory from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import quantize_weight +from llmcompressor.modifiers.quantization.gptq.utils.partitioned_model import PartitionedModel from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier from llmcompressor.modifiers.utils.hooks import LayerCompressorMixin from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward @@ -276,7 +277,10 @@ def calibration_forward( batch_size = self.batch_size dataloader = create_batch_dataloader(dataloader, batch_size=batch_size) with calibration_forward_context(model): - run_calibration_forward(model, dataloader, mask_padding=True) + partitioned_model = PartitionedModel() + partitioned_model.init_forward(model, ["Linear"]) + run_calibration_forward(partitioned_model, dataloader, mask_padding=True) + def pre_compress_module(self, module: torch.nn.Module): if self.batch_size != -1: diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py new file mode 100644 index 000000000..761f7dd28 --- /dev/null +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py @@ -0,0 +1,315 @@ + +from typing import Any, Callable, Dict, List, Set + +import torch +from collections import deque +from transformers import AutoModel +from torch.fx import GraphModule, Graph, Node +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.utils.fx import symbolic_trace + + +class Model(torch.nn.Module): + def __init__(self, vocab_size=4096, d_model=128, n_heads=1, d_ff=256, dropout=0.1): + super(Model, self).__init__() + + self.d_model = d_model + self.n_heads = n_heads + self.head_dim = d_model // n_heads + assert d_model % n_heads == 0, "d_model must be divisible by n_heads" + + # Embedding layer + self.embedding = torch.nn.Embedding(vocab_size, d_model) + + # Linear transformations for queries, keys, and values + self.query_linear = torch.nn.Linear(d_model, d_model) + self.key_linear = torch.nn.Linear(d_model, d_model) + self.value_linear = torch.nn.Linear(d_model, d_model) + + # Output linear layer to combine heads + self.out_linear = torch.nn.Linear(d_model, d_model) + + # Position-wise feed-forward network + self.feed_forward = torch.nn.Sequential( + torch.nn.Linear(d_model, d_ff), + torch.nn.ReLU(), + torch.nn.Linear(d_ff, d_model) + ) + + # Layer normalization layers + self.norm1 = torch.nn.LayerNorm(d_model) + self.norm2 = torch.nn.LayerNorm(d_model) + + # Dropout layer + self.dropout = torch.nn.Dropout(dropout) + + def scaled_dot_product_attention(self, query, key, value): + # Calculate attention scores + scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5) + attn_weights = torch.functional.F.softmax(scores, dim=-1) + output = torch.matmul(attn_weights, value) + return output + + def forward(self, input_ids): + # Apply embedding layer + x = self.embedding(input_ids) # (batch_size, seq_length, d_model) + + batch_size, seq_length, _ = x.size() + + # Linear projections + Q = self.query_linear(x) # (batch_size, seq_length, d_model) + K = self.key_linear(x) # (batch_size, seq_length, d_model) + V = self.value_linear(x) # (batch_size, seq_length, d_model) + + # Split Q, K, V into multiple heads + Q = Q.view(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) # (batch_size, n_heads, seq_length, head_dim) + K = K.view(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) # (batch_size, n_heads, seq_length, head_dim) + V = V.view(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) # (batch_size, n_heads, seq_length, head_dim) + + # Scaled dot-product attention + attn_output = self.scaled_dot_product_attention(Q, K, V) # (batch_size, n_heads, seq_length, head_dim) + + # Concatenate heads + attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model) + + # Apply final linear transformation + attn_output = self.out_linear(attn_output) + + # Add & Norm + x = x + self.dropout(attn_output) + x = self.norm1(x) + + # Feed-forward block + ff_output = self.feed_forward(x) + x = x + self.dropout(ff_output) + x = self.norm2(x) + + return BaseModelOutputWithPast(last_hidden_state=x) + + +def get_target_nodes(graph: GraphModule, targets: List[str]): + target_nodes = [] + for node in graph.graph.nodes: + if ( + node.op == "call_module" and + type(graph.get_submodule(node.target)).__name__ in targets + ): + target_nodes.append(node) + + return target_nodes + + +def check_assumption(graph: Graph) -> bool: + for node in graph.nodes: + for user in node.users: + if node not in user.all_input_nodes: + return False + + for input_node in node.all_input_nodes: + if node not in input_node.users: + return False + + if ( + len(node.users) != len(set(node.users)) or + len(node.all_input_nodes) != len(set(node.all_input_nodes)) + ): + return False + + return True + + +def topological_partition(graph: GraphModule, target_nodes: Set[Node]) -> List[List[Node]]: + # use list representation to maintain topological sorting + assert check_assumption(graph.graph) + + partitions: List[List[Node]] = [[]] + remaining_indegrees = {node: len(node.all_input_nodes) for node in graph.graph.nodes} + partition_index = 0 # global counter, not necessary but ensures partitions are connected + + # start with graph input nodes + queue = deque(node for node in graph.graph.nodes if remaining_indegrees[node] == 0) + while len(queue) > 0: + node = queue.popleft() + + # guarantee targets are assigned to disjoint partitions + if node in target_nodes: + partition_index += 1 + partitions.append([]) + + # assign to partition + partitions[partition_index].append(node) + + # recurse on last indegree only in order to guarantee that + # the node is assigned to maximal partition + for user in node.users: + remaining_indegrees[user] -= 1 + if remaining_indegrees[user] == 0: + queue.append(user) + + assert set().union(*partitions) == set(graph.graph.nodes) + return partitions + + +def partition_graph(model: torch.nn.Module, partitions: List[List[Node]]): + subgraphs = [] + + # create subgraphs + for partition_nodes in partitions: + # create a new graph for the partition + subgraph = Graph(model) + node_map = {} + + # add placeholders for inputs not in this subgraph. use set to deduplicate + new_input_nodes = { + input_node + for node in partition_nodes + for input_node in node.all_input_nodes + if input_node not in partition_nodes + } + for input_node in new_input_nodes: + node_map[input_node] = subgraph.placeholder(input_node.name) + + # add the nodes to subgraph + for node in partition_nodes: + node_map[node] = subgraph.node_copy(node, lambda n: node_map[n]) + + # add an output node to collect all subgraph outputs into a dictionary + if len(subgraph.find_nodes(op="output")) <= 0: + output_dict = { + node.name: node_map[node] + for node in partition_nodes + if any(user not in partition_nodes for user in node.users.keys()) + } + subgraph.output(output_dict) + + # Save the subgraph for this partition + subgraph.lint() + input_names = [node.name for node in subgraph.nodes if node.op == "placeholder"] + subgraphs.append({ + "graph": subgraph, + "code": subgraph.python_code("self"), + "input_names": input_names, + "consumed_names": [], + }) + + print([n for n in subgraph.nodes]) + assert check_assumption(subgraph) + + # populate consumed_names according to when inputs are last used + # in order to vacate the `intermediates` cache and save memory + all_input_names = set().union(*(subgraph["input_names"] for subgraph in subgraphs)) + for input_name in all_input_names: + for subgraph in reversed(subgraphs): + if input_name in subgraph["input_names"]: + subgraph["consumed_names"].append(input_name) + break + else: + assert False + + return subgraphs + + +def gptq_compress(name: str, module: torch.nn.Module, inputs: List[torch.Tensor]): + print(f"gptq_compress {name} {module} {inputs.shape}") + pass + + +class PartitionedModel: + def __init__(self): + self.hook_targets = [] + self.hook_target_nodes = [] + self.graph = None + self.subgraphs = [] + self.model = None + + def register_hook(self, func: Callable, targets: List[str]): + self.hook_targets.append((func, targets)) + + def init_forward(self, model: torch.nn.Module, targets): + self.model = model + + # 1. create graph + self.graph: GraphModule = symbolic_trace(model) + + # 2. identify target nodes + all_target_nodes = get_target_nodes(self.graph, targets) + + # 3. cut into partitions along target nodes + partitions: List[List[Node]] = topological_partition(self.graph, all_target_nodes) + self.subgraphs: List[GraphModule] = partition_graph(model, partitions) + + def forward(self, *args, **kwargs): + model_modules = {name: module for name, module in self.model.named_modules()} + + # 4. perform compression + intermediates = kwargs.copy() + for subgraph_index, subgraph in enumerate(self.subgraphs): + code = subgraph["code"] + exec(code.src, code.globals) + forward_function = code.globals.get("forward") + + inputs = {input_name: intermediates[input_name] for input_name in subgraph["input_names"]} + + # detect and call hooks + for func, target_nodes in self.hook_target_nodes: + target_nodes = set(target_node for target_node in target_nodes) + subgraph_node_names = set(node.name for node in subgraph["graph"].nodes if node.op == "call_module") + + for target_node in target_nodes: + if target_node.name in subgraph_node_names: + assert len(target_node.all_input_nodes) == 1 + + module = model_modules[target_node.target] + input_value = inputs[target_node.all_input_nodes[0].name] + func(target_node.target, module, input_value) + + if subgraph_index < len(self.subgraphs) - 1: + intermediates.update(forward_function(self.model, **inputs)) + + for consumed_name in subgraph["consumed_names"]: + del intermediates[consumed_name] + else: + return forward_function(self.model, **inputs) + + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def eval(self): + self.model.eval() + + def parameters(self): + return self.model.parameters() + + +if __name__ == "__main__": + use_dummy_model = True + sequence_length = 2048 + + if use_dummy_model: + model = Model() + from torch.fx import symbolic_trace + else: + model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B-Instruct") + from transformers.utils.fx import symbolic_trace + + data_loader = [ + {"input_ids": torch.zeros(sequence_length, dtype=torch.int32).reshape(1, sequence_length)}, + #{"input_ids": torch.zeros(sequence_length, dtype=torch.int32).reshape(1, sequence_length)}, + #{"input_ids": torch.zeros(sequence_length, dtype=torch.int32).reshape(1, sequence_length)}, + ] + + # modifier inits + hooked_model = HookedModel() + hooked_model.register_hook(gptq_compress, ["Linear"]) + + # some time after modifier inits but before forward passes + hooked_model.init_forward(model) + + # oneshot/ eval loop + model.eval() + with torch.no_grad(): + for batch in data_loader: + hooked_output = hooked_model.forward(**batch) + model_output = model.forward(**batch) + assert torch.equal(hooked_output["last_hidden_state"], model_output["last_hidden_state"]) From 4872242f8d6a8f7d15e4d6528f32197e832a4550 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 14 Nov 2024 01:27:22 +0000 Subject: [PATCH 077/285] working, although still a little higher memory usage than expected --- examples/quantization_w4a16/llama3_example.py | 4 +- graph_resuming.py | 3 +- .../modifiers/quantization/gptq/base.py | 133 ++++++++---------- .../quantization/gptq/utils/gptq_quantize.py | 31 +++- .../gptq/utils/partitioned_model.py | 86 ++++++++--- src/llmcompressor/modifiers/utils/hooks.py | 25 +--- src/llmcompressor/utils/fsdp/helpers.py | 2 + 7 files changed, 161 insertions(+), 123 deletions(-) diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index 96f80051e..9855490d3 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -23,7 +23,7 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 2 #2048 +NUM_CALIBRATION_SAMPLES = 285 #2048 MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. @@ -58,7 +58,7 @@ def tokenize(sample): # Configure the quantization algorithm to run. # * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"], batch_size=-1, dampening_frac=0.5) +recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"], update_size=NUM_CALIBRATION_SAMPLES, dampening_frac=0.5) # Apply algorithms. oneshot( diff --git a/graph_resuming.py b/graph_resuming.py index 194e6232c..14e0520ea 100644 --- a/graph_resuming.py +++ b/graph_resuming.py @@ -162,6 +162,7 @@ def partition_graph(model: torch.nn.Module, partitions: List[List[Node]]): new_input_nodes = { input_node for node in partition_nodes + if node.op != "get_attr" for input_node in node.all_input_nodes if input_node not in partition_nodes } @@ -275,7 +276,7 @@ def forward(self, *args, **kwargs): if __name__ == "__main__": - use_dummy_model = True + use_dummy_model = False sequence_length = 2048 if use_dummy_model: diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 3618af76d..3909178d3 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -12,11 +12,11 @@ from llmcompressor.core import State from llmcompressor.modifiers import Modifier, ModifierFactory -from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import quantize_weight +from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import accumulate_hessian, make_empty_hessian, quantize_weight from llmcompressor.modifiers.quantization.gptq.utils.partitioned_model import PartitionedModel from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier from llmcompressor.modifiers.utils.hooks import LayerCompressorMixin -from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward +from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException, run_calibration_forward from llmcompressor.observers.base import Observer from llmcompressor.transformers.finetune.data.data_helpers import ( create_batch_dataloader, @@ -31,7 +31,8 @@ fake_quantize, ) -from llmcompressor.utils.pytorch.module import qat_active +from llmcompressor.utils.metric_logging import CompressionLogger +from llmcompressor.utils.pytorch.module import get_no_split_params, qat_active __all__ = ["GPTQModifier"] @@ -104,7 +105,7 @@ class GPTQModifier(Modifier, LayerCompressorMixin): """ sequential_update: bool = True # DEPRECIATED - batch_size: int = -1 + update_size: int = 1 sequential_targets: Union[str, List[str], None] = None block_size: int = 128 dampening_frac: Optional[float] = 0.01 @@ -119,7 +120,8 @@ class GPTQModifier(Modifier, LayerCompressorMixin): disable_quantization_observer_epoch: Optional[float] = None _quantization_modifier: Optional[QuantizationModifier] = PrivateAttr() - _num_batches: int = PrivateAttr() + _hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=lambda: {}) + _num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=lambda: {}) @field_validator("sequential_update", mode="before") def validate_sequential_update(cls, value: bool) -> bool: @@ -197,58 +199,24 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self._quantization_modifier.initialize(state, **kwargs) if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") - - if self.batch_size <= 0: - batch_size = len(state.data.calib.dataset) - else: - batch_size = self.batch_size - self._num_batches = math.ceil(len(state.data.calib.dataset) / batch_size) self.register_hooks(state.model) - #torch.cuda.memory._record_memory_history(max_entries=1_000_000) + torch.cuda.memory._record_memory_history(max_entries=1_000_000) try: self.calibration_forward(state.model, state.data.calib) finally: pass - #torch.cuda.memory._dump_snapshot("bs10.pickle") - #torch.cuda.memory._record_memory_history(enabled=None) - #exit(0) + torch.cuda.memory._dump_snapshot("partition.pickle") + torch.cuda.memory._record_memory_history(enabled=None) + exit(0) self.remove_hooks() - self.finish_compression(state.model) # freeze quantization state.model.apply(freeze_module_quantization) return True - - def finish_compression(self, model: torch.nn.Module): - for module in model.modules(): - quant_args = getattr_chain(module, "quantization_scheme.weights", None) - if quant_args is None: - continue - - with align_module(module): - - if self.batch_size != -1: - weight = module.weight_acc / self._num_batches - delete_offload_parameter(module, "weight_acc") - else: - weight = module.weight - - observer = Observer.load_from_registry( - quant_args.observer, quantization_args=quant_args - ) - scale, zero_point = observer(weight) - weight = fake_quantize( - weight, - scale, - zero_point, - quant_args, - ) - update_offload_parameter(module, "weight", weight) - update_offload_parameter(module, "weight_scale", scale) - update_offload_parameter(module, "weight_zero_point", zero_point) + def on_finalize(self, state: "State", **kwargs) -> bool: """ @@ -271,28 +239,27 @@ def calibration_forward( :param model: model to perform forward pass with :param dataloader: dataloader containing calibration dataset """ - if self.batch_size <= 0: - batch_size = len(dataloader.dataset) - else: - batch_size = self.batch_size - dataloader = create_batch_dataloader(dataloader, batch_size=batch_size) + dataloader = create_batch_dataloader(dataloader, batch_size=1) with calibration_forward_context(model): partitioned_model = PartitionedModel() - partitioned_model.init_forward(model, ["Linear"]) - run_calibration_forward(partitioned_model, dataloader, mask_padding=True) - - - def pre_compress_module(self, module: torch.nn.Module): - if self.batch_size != -1: - print("created aux buffers") - register_offload_parameter(module, "weight_acc", torch.nn.Parameter(torch.zeros_like(module.weight.data), requires_grad=False)) + targets = get_no_split_params(model) + partitioned_model.init_forward(model, targets) + breakpoint() + + model.config.use_cache = False + model.eval() + with torch.no_grad(): + try: + partitioned_model.forward_data(dataloader, mask_padding=True) + except EarlyStopException: + pass def compress_module( self, name: str, module: torch.nn.Module, args: Tuple[torch.Tensor, ...], - ) -> float: + ): """ Quantize a module's weight according to the GPTQ algorithm @@ -302,33 +269,47 @@ def compress_module( :return: total loss from applying weight quantization to this module """ - logger.info(f"Quantizing {name}...") # Assume that first argument is the input inp = args[0] quant_args = getattr_chain(module, "quantization_scheme.weights") - logger.info(f"Using {inp.size(0)} samples") - - with align_module(module): - print(inp.shape) - loss, quantized_weight, _scale, _zero_point, _g_idx = quantize_weight( - module.weight.data, - inp, - quant_args, - blocksize=self.block_size, - percdamp=self.dampening_frac, - module_class=type(module), - ) - if self.batch_size != -1: - module.weight_acc += quantized_weight - update_offload_parameter(module, "weight_acc") - else: + if module not in self._num_samples: + self._hessians[module] = make_empty_hessian(module) + self._num_samples[module] = 0 + + self._hessians[module], self._num_samples[module] = accumulate_hessian( + inp, + type(module), + self._hessians[module], + self._num_samples[module], + ) + + if self._num_samples[module] >= self.update_size: + logger.info(f"Quantizing {name}...") + logger.info(f"Using {self._num_samples[module]} accumulated samples") + with align_module(module), CompressionLogger(module) as comp_logger: + loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( + module.weight.data, + inp, + quant_args, + blocksize=self.block_size, + percdamp=self.dampening_frac, + module_class=type(module), + ) + module.weight -= module.weight module.weight += quantized_weight update_offload_parameter(module, "weight") + update_offload_parameter(module, "weight_scale", scale) + update_offload_parameter(module, "weight_zero_point", zero_point) + if g_idx is not None: + update_offload_parameter(module, "weight_g_idx", g_idx) + + del self._hessians[module] + del self._num_samples[module] - return loss + comp_logger.set_loss(loss) def _build_quant_modifier(self): """ diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index a625e8a7b..384d9fc8e 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -18,6 +18,35 @@ GPTQ_PRECISION = torch.float32 +def make_empty_hessian(module: torch.nn.Module): + weight = module.weight + num_columns = weight.shape[1] + return torch.zeros((num_columns, num_columns), device=weight.device, dtype=GPTQ_PRECISION) + + +def accumulate_hessian(inp: torch.Tensor, module_class: Type[torch.nn.Module], H: Optional[torch.Tensor] = None, num_samples: int = 1) -> Tuple[torch.Tensor, int]: + inp = inp.to(device=H.device) + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + + num_added = inp.shape[0] # note this is the number of dataset samples, not + # multiplied by the sequence length + + if module_class in (torch.nn.Linear, transformers.Conv1D): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + + H *= num_samples / (num_samples + num_added) + num_samples += num_added + + inp = inp.to(dtype=GPTQ_PRECISION) + inp = math.sqrt(2 / num_samples) * inp + H += inp.matmul(inp.t()) + + return H, num_samples + + def compute_hessian(inp: torch.Tensor, module_class: Type[torch.nn.Module], device) -> torch.Tensor: """ Calculate the hessian with respect to the module inputs @@ -255,7 +284,7 @@ def quantize_weight( W = W.reshape(final_shape).to(final_dtype) loss = torch.sum(losses).item() - return loss, W, scale, zero_point, g_idx + return loss, W, scale.to(dtype=final_dtype), zero_point.to(dtype=quant_args.pytorch_dtype()), g_idx def _apply_activation_ordering( diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py index 761f7dd28..0e2042add 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py @@ -6,7 +6,10 @@ from transformers import AutoModel from torch.fx import GraphModule, Graph, Node from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.utils.fx import symbolic_trace +from transformers.utils.fx import symbolic_trace, HFTracer + +from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch +from llmcompressor.pytorch.utils.helpers import tensors_to_device class Model(torch.nn.Module): @@ -123,11 +126,13 @@ def topological_partition(graph: GraphModule, target_nodes: Set[Node]) -> List[L assert check_assumption(graph.graph) partitions: List[List[Node]] = [[]] - remaining_indegrees = {node: len(node.all_input_nodes) for node in graph.graph.nodes} + remaining_indegrees = {node: len([node for node in node.all_input_nodes if node.op != "get_attr"]) for node in graph.graph.nodes} + #remaining_indegrees = {node: len((node for node in node.all_input_nodes)) for node in graph.graph.nodes} partition_index = 0 # global counter, not necessary but ensures partitions are connected # start with graph input nodes - queue = deque(node for node in graph.graph.nodes if remaining_indegrees[node] == 0) + #queue = deque(node for node in graph.graph.nodes if remaining_indegrees[node] == 0)# and node.op != "get_attr") + queue = deque(node for node in graph.graph.nodes if remaining_indegrees[node] == 0 and node.op != "get_attr") while len(queue) > 0: node = queue.popleft() @@ -146,6 +151,17 @@ def topological_partition(graph: GraphModule, target_nodes: Set[Node]) -> List[L if remaining_indegrees[user] == 0: queue.append(user) + for node in graph.graph.nodes: + if node.op == "get_attr": + user_partitions = [] + for user in node.users: + for index in range(len(partitions)): + if user in partitions[index]: + user_partitions.append(index) + break + partition_index = min(user_partitions) + partitions[partition_index].insert(0, node) + assert set().union(*partitions) == set(graph.graph.nodes) return partitions @@ -163,8 +179,9 @@ def partition_graph(model: torch.nn.Module, partitions: List[List[Node]]): new_input_nodes = { input_node for node in partition_nodes + #if node.op != "get_attr" for input_node in node.all_input_nodes - if input_node not in partition_nodes + if input_node not in partition_nodes and input_node.op } for input_node in new_input_nodes: node_map[input_node] = subgraph.placeholder(input_node.name) @@ -228,8 +245,14 @@ def register_hook(self, func: Callable, targets: List[str]): def init_forward(self, model: torch.nn.Module, targets): self.model = model - # 1. create graph - self.graph: GraphModule = symbolic_trace(model) + # 1. trace graph + class CustomTracer(HFTracer): + def is_leaf_module(self, module: torch.nn.Module, module_qualified_name: str) -> bool: + if type(module).__name__ in targets: + return True # Treat as leaf, skip tracing inside this module + return super().is_leaf_module(module, module_qualified_name) + + self.graph: GraphModule = symbolic_trace(model, tracer_cls=CustomTracer) # 2. identify target nodes all_target_nodes = get_target_nodes(self.graph, targets) @@ -238,30 +261,48 @@ def init_forward(self, model: torch.nn.Module, targets): partitions: List[List[Node]] = topological_partition(self.graph, all_target_nodes) self.subgraphs: List[GraphModule] = partition_graph(model, partitions) - def forward(self, *args, **kwargs): - model_modules = {name: module for name, module in self.model.named_modules()} - + def forward_data(self, dataloader, mask_padding: bool = True): # 4. perform compression - intermediates = kwargs.copy() + model_device = next(self.model.parameters()).device + batch_intermediates = [ + tensors_to_device(apply_pad_mask_to_batch(batch), model_device) if mask_padding else tensors_to_device(batch, model_device) + for batch in dataloader + ] + batch_outputs = [None for _ in range(len(dataloader))] + for subgraph_index, subgraph in enumerate(self.subgraphs): code = subgraph["code"] exec(code.src, code.globals) forward_function = code.globals.get("forward") - inputs = {input_name: intermediates[input_name] for input_name in subgraph["input_names"]} + print(f"subgraph_index: {subgraph_index}") + print(batch_intermediates[0].keys()) + + for batch_index in range(len(dataloader)): + intermediates = batch_intermediates[batch_index] + + inputs = {input_name: intermediates[input_name] for input_name in subgraph["input_names"]} + subgraph_output = forward_function(self.model, **inputs) + + for consumed_name in subgraph["consumed_names"]: + del intermediates[consumed_name] - # detect and call hooks - for func, target_nodes in self.hook_target_nodes: - target_nodes = set(target_node for target_node in target_nodes) - subgraph_node_names = set(node.name for node in subgraph["graph"].nodes if node.op == "call_module") + if subgraph_index < len(self.subgraphs) - 1: + intermediates.update(subgraph_output) + else: + batch_outputs[batch_index] = subgraph_output - for target_node in target_nodes: - if target_node.name in subgraph_node_names: - assert len(target_node.all_input_nodes) == 1 + return batch_outputs - module = model_modules[target_node.target] - input_value = inputs[target_node.all_input_nodes[0].name] - func(target_node.target, module, input_value) + def forward(self, *args, **kwargs): + # 4. perform compression + intermediates = kwargs.copy() + for subgraph_index, subgraph in enumerate(self.subgraphs): + code = subgraph["code"] + exec(code.src, code.globals) + forward_function = code.globals.get("forward") + + inputs = {input_name: intermediates[input_name] for input_name in subgraph["input_names"]} if subgraph_index < len(self.subgraphs) - 1: intermediates.update(forward_function(self.model, **inputs)) @@ -300,8 +341,7 @@ def parameters(self): ] # modifier inits - hooked_model = HookedModel() - hooked_model.register_hook(gptq_compress, ["Linear"]) + hooked_model = PartitionedModel() # some time after modifier inits but before forward passes hooked_model.init_forward(model) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 267de838c..14fe90e50 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -101,22 +101,13 @@ class LayerCompressorMixin(HooksMixin): _layer_index = 0 _num_layers = 0 - @abstractmethod - def pre_compress_module( - self, - name: str, - module: torch.nn.Module, - args: Tuple[torch.Tensor, ...], - ) -> float: - raise NotImplementedError() - @abstractmethod def compress_module( self, name: str, module: torch.nn.Module, args: Tuple[torch.Tensor, ...], - ) -> float: + ): raise NotImplementedError() def register_hooks(self, model: torch.nn.Module): @@ -133,17 +124,15 @@ def register_hooks(self, model: torch.nn.Module): for name, module in model.named_modules(): if getattr_chain(module, "quantization_scheme.weights", None) is not None: pre_hook = partial(self.target_pre_forward, name) - post_hook = partial(self.target_post_forward, name) + #post_hook = partial(self.target_post_forward, name) self.register_hook(module.register_forward_pre_hook(pre_hook, with_kwargs=True)) - self.register_hook(module.register_forward_hook(post_hook)) - - self.pre_compress_module(module) + #self.register_hook(module.register_forward_hook(post_hook)) if "head" in name: def hook(module: torch.nn.Module, args: Tuple[Any, ...]): raise EarlyStopException(None, None) - self.register_hook(module.register_forward_pre_hook(hook, with_kwargs=False)) + # self.register_hook(module.register_forward_pre_hook(hook, with_kwargs=False)) # if name in layers.keys(): # pre_hook = partial(self.layer_pre_forward, name) @@ -159,11 +148,7 @@ def target_pre_forward( self, name: str, module: torch.nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any] ): # compress - print(f"compressing {name}") - if True: #self.true_sequential: - with CompressionLogger(module) as comp_logger: - loss = self.compress_module(name, module, args) - comp_logger.set_loss(loss) + self.compress_module(name, module, args) @HooksMixin.hook def target_post_forward( diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py index e58b4f1c3..f2f902344 100644 --- a/src/llmcompressor/utils/fsdp/helpers.py +++ b/src/llmcompressor/utils/fsdp/helpers.py @@ -299,6 +299,8 @@ def update_offload_parameter( raise ValueError("Cannot copy data from meta device. Consider calling with align_module(module) context") if param.data.dtype != data.dtype: + print(name) + print((param.data.dtype, data.dtype)) warnings.warn("TODO") param.data.copy_(data) From 7fa5c3c83ba7e2cf6cbf84fd940781a43631c73e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 14 Nov 2024 01:32:44 +0000 Subject: [PATCH 078/285] offload intermediates --- .../modifiers/quantization/gptq/utils/partitioned_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py index 0e2042add..3f4455e83 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py @@ -265,7 +265,7 @@ def forward_data(self, dataloader, mask_padding: bool = True): # 4. perform compression model_device = next(self.model.parameters()).device batch_intermediates = [ - tensors_to_device(apply_pad_mask_to_batch(batch), model_device) if mask_padding else tensors_to_device(batch, model_device) + apply_pad_mask_to_batch(batch) if mask_padding else batch for batch in dataloader ] batch_outputs = [None for _ in range(len(dataloader))] @@ -282,7 +282,9 @@ def forward_data(self, dataloader, mask_padding: bool = True): intermediates = batch_intermediates[batch_index] inputs = {input_name: intermediates[input_name] for input_name in subgraph["input_names"]} + inputs = tensors_to_device(inputs, model_device) subgraph_output = forward_function(self.model, **inputs) + subgraph_output = tensors_to_device(subgraph_output, "cpu") for consumed_name in subgraph["consumed_names"]: del intermediates[consumed_name] From 1c45963a8849d35c47f94510c4da62339c7d9f57 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 14 Nov 2024 01:42:47 +0000 Subject: [PATCH 079/285] cleanup --- .../gptq/utils/partitioned_model.py | 161 ++---------------- 1 file changed, 13 insertions(+), 148 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py index 3f4455e83..9df749a5e 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py @@ -12,84 +12,6 @@ from llmcompressor.pytorch.utils.helpers import tensors_to_device -class Model(torch.nn.Module): - def __init__(self, vocab_size=4096, d_model=128, n_heads=1, d_ff=256, dropout=0.1): - super(Model, self).__init__() - - self.d_model = d_model - self.n_heads = n_heads - self.head_dim = d_model // n_heads - assert d_model % n_heads == 0, "d_model must be divisible by n_heads" - - # Embedding layer - self.embedding = torch.nn.Embedding(vocab_size, d_model) - - # Linear transformations for queries, keys, and values - self.query_linear = torch.nn.Linear(d_model, d_model) - self.key_linear = torch.nn.Linear(d_model, d_model) - self.value_linear = torch.nn.Linear(d_model, d_model) - - # Output linear layer to combine heads - self.out_linear = torch.nn.Linear(d_model, d_model) - - # Position-wise feed-forward network - self.feed_forward = torch.nn.Sequential( - torch.nn.Linear(d_model, d_ff), - torch.nn.ReLU(), - torch.nn.Linear(d_ff, d_model) - ) - - # Layer normalization layers - self.norm1 = torch.nn.LayerNorm(d_model) - self.norm2 = torch.nn.LayerNorm(d_model) - - # Dropout layer - self.dropout = torch.nn.Dropout(dropout) - - def scaled_dot_product_attention(self, query, key, value): - # Calculate attention scores - scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5) - attn_weights = torch.functional.F.softmax(scores, dim=-1) - output = torch.matmul(attn_weights, value) - return output - - def forward(self, input_ids): - # Apply embedding layer - x = self.embedding(input_ids) # (batch_size, seq_length, d_model) - - batch_size, seq_length, _ = x.size() - - # Linear projections - Q = self.query_linear(x) # (batch_size, seq_length, d_model) - K = self.key_linear(x) # (batch_size, seq_length, d_model) - V = self.value_linear(x) # (batch_size, seq_length, d_model) - - # Split Q, K, V into multiple heads - Q = Q.view(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) # (batch_size, n_heads, seq_length, head_dim) - K = K.view(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) # (batch_size, n_heads, seq_length, head_dim) - V = V.view(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) # (batch_size, n_heads, seq_length, head_dim) - - # Scaled dot-product attention - attn_output = self.scaled_dot_product_attention(Q, K, V) # (batch_size, n_heads, seq_length, head_dim) - - # Concatenate heads - attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model) - - # Apply final linear transformation - attn_output = self.out_linear(attn_output) - - # Add & Norm - x = x + self.dropout(attn_output) - x = self.norm1(x) - - # Feed-forward block - ff_output = self.feed_forward(x) - x = x + self.dropout(ff_output) - x = self.norm2(x) - - return BaseModelOutputWithPast(last_hidden_state=x) - - def get_target_nodes(graph: GraphModule, targets: List[str]): target_nodes = [] for node in graph.graph.nodes: @@ -126,13 +48,18 @@ def topological_partition(graph: GraphModule, target_nodes: Set[Node]) -> List[L assert check_assumption(graph.graph) partitions: List[List[Node]] = [[]] - remaining_indegrees = {node: len([node for node in node.all_input_nodes if node.op != "get_attr"]) for node in graph.graph.nodes} - #remaining_indegrees = {node: len((node for node in node.all_input_nodes)) for node in graph.graph.nodes} - partition_index = 0 # global counter, not necessary but ensures partitions are connected + remaining_indegrees = { + node: len([node for node in node.all_input_nodes if node.op != "get_attr"]) + for node in graph.graph.nodes + } + partition_index = 0 # global counter # start with graph input nodes - #queue = deque(node for node in graph.graph.nodes if remaining_indegrees[node] == 0)# and node.op != "get_attr") - queue = deque(node for node in graph.graph.nodes if remaining_indegrees[node] == 0 and node.op != "get_attr") + queue = deque( + node + for node in graph.graph.nodes + if remaining_indegrees[node] == 0 and node.op != "get_attr" + ) while len(queue) > 0: node = queue.popleft() @@ -151,6 +78,9 @@ def topological_partition(graph: GraphModule, target_nodes: Set[Node]) -> List[L if remaining_indegrees[user] == 0: queue.append(user) + # a perfect solution would involve implicitly consolodating partition indices so + # that each node is assigned to the maximum partition possible (in order to delay + # execution as long as possible), but this covers the most costly case (get_attr) for node in graph.graph.nodes: if node.op == "get_attr": user_partitions = [] @@ -226,11 +156,6 @@ def partition_graph(model: torch.nn.Module, partitions: List[List[Node]]): return subgraphs -def gptq_compress(name: str, module: torch.nn.Module, inputs: List[torch.Tensor]): - print(f"gptq_compress {name} {module} {inputs.shape}") - pass - - class PartitionedModel: def __init__(self): self.hook_targets = [] @@ -295,63 +220,3 @@ def forward_data(self, dataloader, mask_padding: bool = True): batch_outputs[batch_index] = subgraph_output return batch_outputs - - def forward(self, *args, **kwargs): - # 4. perform compression - intermediates = kwargs.copy() - for subgraph_index, subgraph in enumerate(self.subgraphs): - code = subgraph["code"] - exec(code.src, code.globals) - forward_function = code.globals.get("forward") - - inputs = {input_name: intermediates[input_name] for input_name in subgraph["input_names"]} - - if subgraph_index < len(self.subgraphs) - 1: - intermediates.update(forward_function(self.model, **inputs)) - - for consumed_name in subgraph["consumed_names"]: - del intermediates[consumed_name] - else: - return forward_function(self.model, **inputs) - - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def eval(self): - self.model.eval() - - def parameters(self): - return self.model.parameters() - - -if __name__ == "__main__": - use_dummy_model = True - sequence_length = 2048 - - if use_dummy_model: - model = Model() - from torch.fx import symbolic_trace - else: - model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B-Instruct") - from transformers.utils.fx import symbolic_trace - - data_loader = [ - {"input_ids": torch.zeros(sequence_length, dtype=torch.int32).reshape(1, sequence_length)}, - #{"input_ids": torch.zeros(sequence_length, dtype=torch.int32).reshape(1, sequence_length)}, - #{"input_ids": torch.zeros(sequence_length, dtype=torch.int32).reshape(1, sequence_length)}, - ] - - # modifier inits - hooked_model = PartitionedModel() - - # some time after modifier inits but before forward passes - hooked_model.init_forward(model) - - # oneshot/ eval loop - model.eval() - with torch.no_grad(): - for batch in data_loader: - hooked_output = hooked_model.forward(**batch) - model_output = model.forward(**batch) - assert torch.equal(hooked_output["last_hidden_state"], model_output["last_hidden_state"]) From 65b3e5b266a9b617b593357254582baaf54f9306 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 14 Nov 2024 18:59:12 +0000 Subject: [PATCH 080/285] better comments, support sending non-tensors to device --- .../modifiers/quantization/gptq/base.py | 8 +++++--- .../quantization/gptq/utils/partitioned_model.py | 2 +- src/llmcompressor/pytorch/utils/helpers.py | 15 ++++++--------- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 3909178d3..9ddb6c5ce 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -200,7 +200,10 @@ def on_initialize(self, state: "State", **kwargs) -> bool: if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") + # apply modifier self.register_hooks(state.model) + + # feed data torch.cuda.memory._record_memory_history(max_entries=1_000_000) try: self.calibration_forward(state.model, state.data.calib) @@ -210,14 +213,12 @@ def on_initialize(self, state: "State", **kwargs) -> bool: torch.cuda.memory._record_memory_history(enabled=None) exit(0) + # finalize stuff self.remove_hooks() - - # freeze quantization state.model.apply(freeze_module_quantization) return True - def on_finalize(self, state: "State", **kwargs) -> bool: """ disable the quantization observers used by the OBCQ algorithm @@ -274,6 +275,7 @@ def compress_module( inp = args[0] quant_args = getattr_chain(module, "quantization_scheme.weights") + # TODO: attach as parameters to the module to allow them to be offloaded if module not in self._num_samples: self._hessians[module] = make_empty_hessian(module) self._num_samples[module] = 0 diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py index 9df749a5e..810f82670 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py @@ -177,7 +177,7 @@ def is_leaf_module(self, module: torch.nn.Module, module_qualified_name: str) -> return True # Treat as leaf, skip tracing inside this module return super().is_leaf_module(module, module_qualified_name) - self.graph: GraphModule = symbolic_trace(model, tracer_cls=CustomTracer) + self.graph: GraphModule = symbolic_trace(model, disable_check=True, tracer_cls=CustomTracer) # 2. identify target nodes all_target_nodes = get_target_nodes(self.graph, targets) diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py index 1a0724e6c..2f902fa43 100644 --- a/src/llmcompressor/pytorch/utils/helpers.py +++ b/src/llmcompressor/pytorch/utils/helpers.py @@ -289,8 +289,7 @@ def tensors_to_device( Default function for putting a tensor or collection of tensors to the proper device. Returns the tensor references after being placed on the proper device. - Supported use cases: - - single tensor + Recursive cases: - Dictionary of single tensors - Dictionary of iterable of tensors - Dictionary of dictionary of tensors @@ -303,9 +302,6 @@ def tensors_to_device( ex: 'cpu', 'cuda', 'cuda:1' :return: the tensors or collection of tensors after being placed on the device """ - if isinstance(tensors, Tensor): - return tensors.to(device) - if isinstance(tensors, OrderedDict): return OrderedDict( [(key, tensors_to_device(tens, device)) for key, tens in tensors.items()] @@ -319,10 +315,11 @@ def tensors_to_device( if isinstance(tensors, Iterable): return [tensors_to_device(tens, device) for tens in tensors] - - raise ValueError( - "unrecognized type for tensors given of {}".format(tensors.__class__.__name__) - ) + + if isinstance(tensors, Tensor): + return tensors.to(device) + + return tensors def tensors_to_precision( From 53d0601b0ae147179e3663580d305a5da52c3b9d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 14 Nov 2024 21:12:20 +0000 Subject: [PATCH 081/285] remove breakpoint, fix move_tensors_to_device --- .../modifiers/quantization/gptq/base.py | 14 +++++--------- src/llmcompressor/pytorch/utils/helpers.py | 6 +++--- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 9ddb6c5ce..aef38f6c1 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -245,15 +245,11 @@ def calibration_forward( partitioned_model = PartitionedModel() targets = get_no_split_params(model) partitioned_model.init_forward(model, targets) - breakpoint() - - model.config.use_cache = False - model.eval() - with torch.no_grad(): - try: - partitioned_model.forward_data(dataloader, mask_padding=True) - except EarlyStopException: - pass + + try: + partitioned_model.forward_data(dataloader, mask_padding=True) + except EarlyStopException: + pass def compress_module( self, diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py index 2f902fa43..db66793b4 100644 --- a/src/llmcompressor/pytorch/utils/helpers.py +++ b/src/llmcompressor/pytorch/utils/helpers.py @@ -302,6 +302,9 @@ def tensors_to_device( ex: 'cpu', 'cuda', 'cuda:1' :return: the tensors or collection of tensors after being placed on the device """ + if isinstance(tensors, Tensor): + return tensors.to(device) + if isinstance(tensors, OrderedDict): return OrderedDict( [(key, tensors_to_device(tens, device)) for key, tens in tensors.items()] @@ -316,9 +319,6 @@ def tensors_to_device( if isinstance(tensors, Iterable): return [tensors_to_device(tens, device) for tens in tensors] - if isinstance(tensors, Tensor): - return tensors.to(device) - return tensors From 4da451b98ea9c41bb42053da982ce3c0c63d5a78 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 14 Nov 2024 21:55:00 +0000 Subject: [PATCH 082/285] woof --- .../modifiers/quantization/gptq/base.py | 8 ++++---- src/llmcompressor/pytorch/model_load/helpers.py | 13 +++++++------ 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index aef38f6c1..6401b7413 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -204,14 +204,14 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self.register_hooks(state.model) # feed data - torch.cuda.memory._record_memory_history(max_entries=1_000_000) + #torch.cuda.memory._record_memory_history(max_entries=1_000_000) try: self.calibration_forward(state.model, state.data.calib) finally: pass - torch.cuda.memory._dump_snapshot("partition.pickle") - torch.cuda.memory._record_memory_history(enabled=None) - exit(0) + #torch.cuda.memory._dump_snapshot("partition.pickle") + #torch.cuda.memory._record_memory_history(enabled=None) + #exit(0) # finalize stuff self.remove_hooks() diff --git a/src/llmcompressor/pytorch/model_load/helpers.py b/src/llmcompressor/pytorch/model_load/helpers.py index 17001a52b..e04476f58 100644 --- a/src/llmcompressor/pytorch/model_load/helpers.py +++ b/src/llmcompressor/pytorch/model_load/helpers.py @@ -220,9 +220,10 @@ def _copy_python_files_from_model_cache(model: Module, save_path: str): import os import shutil - cache_dir = config._name_or_path - for file in os.listdir(cache_dir): - full_file_name = os.path.join(cache_dir, file) - if file.endswith(".py") and os.path.isfile(full_file_name): - logger.debug(f"Transferring {full_file_name} to {save_path}") - shutil.copy(full_file_name, save_path) + if os.path.exists(cache_dir): + cache_dir = config._name_or_path + for file in os.listdir(cache_dir): + full_file_name = os.path.join(cache_dir, file) + if file.endswith(".py") and os.path.isfile(full_file_name): + logger.debug(f"Transferring {full_file_name} to {save_path}") + shutil.copy(full_file_name, save_path) From c77a7fcd62c1e1f1aee9b985ea50e66c99605c84 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 14 Nov 2024 21:56:05 +0000 Subject: [PATCH 083/285] fix thing --- src/llmcompressor/pytorch/model_load/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/pytorch/model_load/helpers.py b/src/llmcompressor/pytorch/model_load/helpers.py index e04476f58..c5d179043 100644 --- a/src/llmcompressor/pytorch/model_load/helpers.py +++ b/src/llmcompressor/pytorch/model_load/helpers.py @@ -220,8 +220,8 @@ def _copy_python_files_from_model_cache(model: Module, save_path: str): import os import shutil + cache_dir = config._name_or_path if os.path.exists(cache_dir): - cache_dir = config._name_or_path for file in os.listdir(cache_dir): full_file_name = os.path.join(cache_dir, file) if file.endswith(".py") and os.path.isfile(full_file_name): From 924943496481e698fd7f7027f4ec86ad3bfdcbea Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 14 Nov 2024 23:34:40 +0000 Subject: [PATCH 084/285] remove LayerCompressorMixin, add hooks tests --- .../modifiers/quantization/gptq/base.py | 59 +++--- .../gptq/utils/partitioned_model.py | 60 ++++--- src/llmcompressor/modifiers/utils/hooks.py | 169 ++++-------------- .../modifiers/utils/test_hooks.py | 81 +++++++++ 4 files changed, 182 insertions(+), 187 deletions(-) create mode 100644 tests/llmcompressor/modifiers/utils/test_hooks.py diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 6401b7413..ab2daa1a3 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -1,3 +1,4 @@ +from functools import partial import warnings from typing import Any, Dict, List, Optional, Tuple, Union @@ -15,7 +16,7 @@ from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import accumulate_hessian, make_empty_hessian, quantize_weight from llmcompressor.modifiers.quantization.gptq.utils.partitioned_model import PartitionedModel from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier -from llmcompressor.modifiers.utils.hooks import LayerCompressorMixin +from llmcompressor.modifiers.utils.hooks import HooksMixin, LayerCompressorMixin from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException, run_calibration_forward from llmcompressor.observers.base import Observer from llmcompressor.transformers.finetune.data.data_helpers import ( @@ -37,7 +38,7 @@ __all__ = ["GPTQModifier"] -class GPTQModifier(Modifier, LayerCompressorMixin): +class GPTQModifier(Modifier, HooksMixin): """ Modifier for applying the one-shot OBCQ algorithm to a model @@ -193,6 +194,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: :param state: session state storing input model and calibration data """ + # initialize quantization modifier if not self.initialized_structure_: self.on_initialize_structure(state, **kwargs) if self._quantization_modifier: @@ -200,22 +202,23 @@ def on_initialize(self, state: "State", **kwargs) -> bool: if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") - # apply modifier - self.register_hooks(state.model) + # register hooks + for name, module in state.model.named_modules(): + if getattr_chain(module, "quantization_scheme.weights", None) is not None: + post_hook = partial(self.compress_module, name) + self.register_forward_hook(module, post_hook) + + if "head" in name: + def hook(module: torch.nn.Module, args: Tuple[Any, ...]): + raise EarlyStopException(None, None) # feed data - #torch.cuda.memory._record_memory_history(max_entries=1_000_000) - try: - self.calibration_forward(state.model, state.data.calib) - finally: - pass - #torch.cuda.memory._dump_snapshot("partition.pickle") - #torch.cuda.memory._record_memory_history(enabled=None) - #exit(0) - - # finalize stuff - self.remove_hooks() - state.model.apply(freeze_module_quantization) + dataloader = create_batch_dataloader(dataloader, batch_size=1) + with calibration_forward_context(state.model): + targets = get_no_split_params(state.model) + partitioned_model = PartitionedModel() + partitioned_model.init_forward(state.model, targets) + partitioned_model.forward_data(dataloader, mask_padding=True) return True @@ -228,28 +231,10 @@ def on_finalize(self, state: "State", **kwargs) -> bool: if self._quantization_modifier: self._quantization_modifier.finalize(state, **kwargs) - return True - - def calibration_forward( - self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader - ): - """ - Perform calibration forward pass with one batch whose size is the size - of the dataset - - :param model: model to perform forward pass with - :param dataloader: dataloader containing calibration dataset - """ - dataloader = create_batch_dataloader(dataloader, batch_size=1) - with calibration_forward_context(model): - partitioned_model = PartitionedModel() - targets = get_no_split_params(model) - partitioned_model.init_forward(model, targets) + self.remove_hooks() + state.model.apply(freeze_module_quantization) - try: - partitioned_model.forward_data(dataloader, mask_padding=True) - except EarlyStopException: - pass + return True def compress_module( self, diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py index 810f82670..549a5f534 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py @@ -1,4 +1,5 @@ +import contextlib from typing import Any, Callable, Dict, List, Set import torch @@ -8,7 +9,8 @@ from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.utils.fx import symbolic_trace, HFTracer -from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch +from llmcompressor.modifiers.utils.hooks import HooksMixin +from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException, apply_pad_mask_to_batch from llmcompressor.pytorch.utils.helpers import tensors_to_device @@ -186,7 +188,13 @@ def is_leaf_module(self, module: torch.nn.Module, module_qualified_name: str) -> partitions: List[List[Node]] = topological_partition(self.graph, all_target_nodes) self.subgraphs: List[GraphModule] = partition_graph(model, partitions) - def forward_data(self, dataloader, mask_padding: bool = True): + def forward_data( + self, + dataloader, + mask_padding: bool = True, + run_twice: bool = True + ): + # TODO: give option to skip lm_head # 4. perform compression model_device = next(self.model.parameters()).device batch_intermediates = [ @@ -200,23 +208,35 @@ def forward_data(self, dataloader, mask_padding: bool = True): exec(code.src, code.globals) forward_function = code.globals.get("forward") - print(f"subgraph_index: {subgraph_index}") - print(batch_intermediates[0].keys()) - - for batch_index in range(len(dataloader)): - intermediates = batch_intermediates[batch_index] - - inputs = {input_name: intermediates[input_name] for input_name in subgraph["input_names"]} - inputs = tensors_to_device(inputs, model_device) - subgraph_output = forward_function(self.model, **inputs) - subgraph_output = tensors_to_device(subgraph_output, "cpu") - - for consumed_name in subgraph["consumed_names"]: - del intermediates[consumed_name] - - if subgraph_index < len(self.subgraphs) - 1: - intermediates.update(subgraph_output) - else: - batch_outputs[batch_index] = subgraph_output + if run_twice: + for batch_index in range(len(dataloader)): + intermediates = batch_intermediates[batch_index] + inputs = {input_name: intermediates[input_name] for input_name in subgraph["input_names"]} + inputs = tensors_to_device(inputs, model_device) + try: + forward_function(self.model, **inputs) + except EarlyStopException: + pass + + with HooksMixin.disable_hooks() if run_twice else contextlib.nullcontext(): + for batch_index in range(len(dataloader)): + intermediates = batch_intermediates[batch_index] + + inputs = {input_name: intermediates[input_name] for input_name in subgraph["input_names"]} + inputs = tensors_to_device(inputs, model_device) + try: + subgraph_output = forward_function(self.model, **inputs) + except EarlyStopException: + subgraph_output = None + pass + subgraph_output = tensors_to_device(subgraph_output, "cpu") + + for consumed_name in subgraph["consumed_names"]: + del intermediates[consumed_name] + + if subgraph_index < len(self.subgraphs) - 1: + intermediates.update(subgraph_output) + else: + batch_outputs[batch_index] = subgraph_output return batch_outputs diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 14fe90e50..93ab0e64b 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -1,7 +1,7 @@ import contextlib from abc import abstractmethod -from functools import partial -from typing import Any, Callable, ClassVar, Dict, List, Set, Tuple, Union +from functools import wraps +from typing import Any, Callable, ClassVar, Dict, Iterable, List, Set, Tuple, Union import torch from loguru import logger @@ -18,33 +18,25 @@ class HooksMixin(BaseModel): - """ " + """ Class to manage the registration, disabling, and removal of hooks. Registering and removing hooks should be handled by modifier classes which inherit from this mixin, while disabling hooks should disable all hooks across modifiers. - Modifiers which implement hooks should use the @HooksMixin.hook decorator - Modifiers must pass registered hooks handles to self.register_hook() and must - remove hooks when finished using self.remove_hooks() + Modifiers which implement hooks should register them using + `self.register_..._hook(module, hook)` rather than the usual + `module.register_..._hook(hook)`. Modifiers should remove hooks with + `self.remove_hooks()` Lifecycle: - - Modifier.register_hooks(model) - - model.forward() - - Modifier.remove_hooks() + - self = Modifier(HooksMixin)(...) + - self.register_forward_hook(module, hook) + - with HooksMixin.disable_hooks(): model.forward() + - self.remove_hooks() """ - _HOOKS_DISABLED: ClassVar[bool] = False - _hooks: List[RemovableHandle] = [] - - @classmethod - def hook(cls, func: Callable[[Any], Any]): - def wrapped(*args, **kwargs): - if cls._HOOKS_DISABLED: - return - - return func(*args, **kwargs) - - return wrapped + _HOOKS_DISABLED: ClassVar[bool] = False # attached to global HooksMixin + _hooks: List[RemovableHandle] = [] # attached to local subclasses @classmethod @contextlib.contextmanager @@ -59,13 +51,21 @@ def disable_hooks(cls): finally: cls._HOOKS_DISABLED = False - def register_hook(self, handle: RemovableHandle): - """ - Usage: self.register_hook(module.register_forward_hook(...)) + def register_forward_pre_hook( + self, + module: torch.nn.Module, + func: Callable[[Any], Any], + **kwargs, + ): + self._register_hook("register_forward_pre_hook", module, func, **kwargs) - :param handle: handle of added hook - """ - self._hooks.append(handle) + def register_forward_hook( + self, + module: torch.nn.Module, + func: Callable[[Any], Any], + **kwargs, + ): + self._register_hook("register_forward_hook", module, func, **kwargs) def remove_hooks(self): """ @@ -74,110 +74,19 @@ def remove_hooks(self): for hook in self._hooks: hook.remove() - -class LayerCompressorMixin(HooksMixin): - """ - Apply a given compression function to a model during the model's calibration - forward pass - - Lifecycle: - - QuantizationModifier.initialize(model) - - Modifier.register_hooks(model) - - model.forward() - - compress_fn(name, target_module, args) - - Modifier.remove_hooks() - - :ivar true_sequential: Used to control the granularity of compression updates - through the forward pass. Set to True to use the weight-compressed outputs - of each module, set to False to use the weight-compressed outputs of each - layer (transformer block), defaults to False - :ivar sequential_targets: list of layer names to compress during GPTQ, or - '__ALL__' to compress every layer in the model - :ivar compresss_module: Function to be called on target modules - """ - - sequential_targets: bool - - _layer_index = 0 - _num_layers = 0 - - @abstractmethod - def compress_module( + def _register_hook( self, - name: str, + register_func_name: str, module: torch.nn.Module, - args: Tuple[torch.Tensor, ...], + func: Callable[[Any], Any], + **kwargs, ): - raise NotImplementedError() - - def register_hooks(self, model: torch.nn.Module): - # find layers (used for printing even if true_sequential=True) - # if no targets are provided, default to the modules that shouldn't be - # split by FSDP. For Transformers models this is equivalent to the - # decoder layers (ie LlamaDecoderLayer) - # sequential_targets = self.sequential_targets - # if sequential_targets is None: - # sequential_targets = get_no_split_params(model) - # layers = get_layers(sequential_targets, model) - # self._num_layers = len(layers) - - for name, module in model.named_modules(): - if getattr_chain(module, "quantization_scheme.weights", None) is not None: - pre_hook = partial(self.target_pre_forward, name) - #post_hook = partial(self.target_post_forward, name) - self.register_hook(module.register_forward_pre_hook(pre_hook, with_kwargs=True)) - #self.register_hook(module.register_forward_hook(post_hook)) - - if "head" in name: - def hook(module: torch.nn.Module, args: Tuple[Any, ...]): - raise EarlyStopException(None, None) - - # self.register_hook(module.register_forward_pre_hook(hook, with_kwargs=False)) - - # if name in layers.keys(): - # pre_hook = partial(self.layer_pre_forward, name) - # post_hook = partial(self.layer_post_forward, name) - # self.register_hook(module.register_forward_pre_hook(pre_hook, with_kwargs=True)) - # self.register_hook( - # module.register_forward_hook(post_hook, with_kwargs=True) - # ) + @wraps(func) + def wrapped_hook(*args, **kwargs): + if HooksMixin._HOOKS_DISABLED: + return + + return func(*args, **kwargs) - - @HooksMixin.hook - def target_pre_forward( - self, name: str, module: torch.nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any] - ): - # compress - self.compress_module(name, module, args) - - @HooksMixin.hook - def target_post_forward( - self, - name: str, - module: torch.nn.Module, - args: Tuple[Any, ...], - output: Tuple[Any, ...], - ): - print(f"post {name}") - - - @HooksMixin.hook - def layer_pre_forward(self, name: str, layer: torch.nn.Module, args: Any, kwargs): - logger.info( - f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====" - ) - - - @HooksMixin.hook - def layer_post_forward( - self, - name: str, - layer: torch.nn.Module, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - output: Tuple[Any, ...], - ): - print(f"post {name}") - self._layer_index += 1 - - return output + handle = getattr(module, register_func_name)(wrapped_hook, **kwargs) + self._hooks.append(handle) \ No newline at end of file diff --git a/tests/llmcompressor/modifiers/utils/test_hooks.py b/tests/llmcompressor/modifiers/utils/test_hooks.py new file mode 100644 index 000000000..ef6f00f52 --- /dev/null +++ b/tests/llmcompressor/modifiers/utils/test_hooks.py @@ -0,0 +1,81 @@ +import torch + +from llmcompressor.modifiers.modifier import Modifier +from llmcompressor.modifiers.utils.hooks import HooksMixin + + +class DummyModel(torch.nn.Module): + def __init__(self): + super(DummyModel, self).__init__() + + + self.linear1 = torch.nn.Linear(1, 2) + self.linear2 = torch.nn.Linear(2, 3) + self.linear3 = torch.nn.Linear(3, 1) + self.dummy_inputs = torch.tensor([0.0]) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + + return x + +class DummyMod(HooksMixin): + hook_called: bool = False + + def hook(self, *args, **kwargs): + self.hook_called = True + +class ModA(DummyMod): + pass + + +class ModB(DummyMod): + pass + + +def test_register_hook(): + model = DummyModel() + + mod_a = ModA() + mod_a.register_forward_hook(model.linear1, mod_a.hook) + + mod_b = ModB() + mod_b.register_forward_pre_hook(model.linear2, mod_b.hook) + + model(model.dummy_inputs) + assert mod_a.hook_called and mod_b.hook_called + + +def test_remove_hooks(): + model = DummyModel() + + mod_a = ModA() + mod_a.register_forward_hook(model.linear1, mod_a.hook) + + mod_b = ModB() + mod_b.register_forward_pre_hook(model.linear2, mod_b.hook) + mod_b.remove_hooks() + + model(model.dummy_inputs) + assert mod_a.hook_called and not mod_b.hook_called + + +def test_disable_hooks(): + model = DummyModel() + + mod_a = ModA() + mod_a.register_forward_hook(model.linear1, mod_a.hook) + + mod_b = ModB() + mod_b.register_forward_pre_hook(model.linear2, mod_b.hook) + + with HooksMixin.disable_hooks(): + model(model.dummy_inputs) + assert not mod_a.hook_called and not mod_b.hook_called + + mod_a.hook_called = False + mod_b.hook_called = False + model(model.dummy_inputs) + assert mod_a.hook_called and mod_b.hook_called \ No newline at end of file From 2690e10dab76eef2001de15f83077a800c344d14 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 14 Nov 2024 23:49:40 +0000 Subject: [PATCH 085/285] Implement HooksMixin Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/utils/hooks.py | 64 +++++++++++++++ .../modifiers/utils/test_hooks.py | 81 +++++++++++++++++++ 2 files changed, 145 insertions(+) create mode 100644 src/llmcompressor/modifiers/utils/hooks.py create mode 100644 tests/llmcompressor/modifiers/utils/test_hooks.py diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py new file mode 100644 index 000000000..ae4d82456 --- /dev/null +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -0,0 +1,64 @@ +import contextlib +from functools import wraps +from typing import Any, Callable, ClassVar, List + +import torch +from pydantic import BaseModel +from torch.utils.hooks import RemovableHandle + +__all__ = ["HooksMixin"] + + +class HooksMixin(BaseModel): + """ + Mixin to manage hook registration, disabling, and removal. + Modifiers should use `self.register_hook(module, hook, hook_type)` + for hook registration and `self.remove_hooks()` for removal. + + Modifiers which implement hooks should register them using + `self.register_..._hook(module, hook)` rather than the usual + `module.register_..._hook(hook)`. Modifiers should remove hooks with + `self.remove_hooks()` + + Lifecycle: + - modifier.register_forward_hook(module, hook) + - with HooksMixin.disable_hooks(): model.forward() + - modifier.remove_hooks() + """ + + _HOOKS_DISABLED: ClassVar[bool] = False # attached to global HooksMixin + _hooks: List[RemovableHandle] = [] # attached to local subclasses + + @classmethod + @contextlib.contextmanager + def disable_hooks(cls): + """Disable all hooks across all modifiers""" + try: + cls._HOOKS_DISABLED = True + yield + finally: + cls._HOOKS_DISABLED = False + + def register_hook( + self, + module: torch.nn.Module, + func: Callable[[Any], Any], + hook_type: str, + **kwargs, + ): + @wraps(func) + def wrapped_hook(*args, **kwargs): + if HooksMixin._HOOKS_DISABLED: + return + + return func(*args, **kwargs) + + handle = getattr(module, f"register_{hook_type}_hook")(wrapped_hook, **kwargs) + self._hooks.append(handle) + + def remove_hooks(self): + """ + Remove all hooks belonging to a modifier + """ + for hook in self._hooks: + hook.remove() diff --git a/tests/llmcompressor/modifiers/utils/test_hooks.py b/tests/llmcompressor/modifiers/utils/test_hooks.py new file mode 100644 index 000000000..79ab3a1b4 --- /dev/null +++ b/tests/llmcompressor/modifiers/utils/test_hooks.py @@ -0,0 +1,81 @@ +import torch + +from llmcompressor.modifiers.utils.hooks import HooksMixin + + +class DummyModel(torch.nn.Module): + def __init__(self): + super(DummyModel, self).__init__() + + self.linear1 = torch.nn.Linear(1, 2) + self.linear2 = torch.nn.Linear(2, 3) + self.linear3 = torch.nn.Linear(3, 1) + self.dummy_inputs = torch.tensor([0.0]) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + + return x + + +class DummyMod(HooksMixin): + hook_called: bool = False + + def hook(self, *args, **kwargs): + self.hook_called = True + + +class ModA(DummyMod): + pass + + +class ModB(DummyMod): + pass + + +def test_register_hook(): + model = DummyModel() + + mod_a = ModA() + mod_a.register_hook(model.linear1, mod_a.hook, "forward") + + mod_b = ModB() + mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre") + + model(model.dummy_inputs) + assert mod_a.hook_called and mod_b.hook_called + + +def test_remove_hooks(): + model = DummyModel() + + mod_a = ModA() + mod_a.register_hook(model.linear1, mod_a.hook, "forward") + + mod_b = ModB() + mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre") + mod_b.remove_hooks() + + model(model.dummy_inputs) + assert mod_a.hook_called and not mod_b.hook_called + + +def test_disable_hooks(): + model = DummyModel() + + mod_a = ModA() + mod_a.register_hook(model.linear1, mod_a.hook, "forward") + + mod_b = ModB() + mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre") + + with HooksMixin.disable_hooks(): + model(model.dummy_inputs) + assert not mod_a.hook_called and not mod_b.hook_called + + mod_a.hook_called = False + mod_b.hook_called = False + model(model.dummy_inputs) + assert mod_a.hook_called and mod_b.hook_called From 004f5c75fe9b3cf4ea703a3685d826acc67d3ea0 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 15 Nov 2024 00:05:27 +0000 Subject: [PATCH 086/285] add docstring Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/utils/hooks.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index ae4d82456..c73cf975f 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -42,23 +42,33 @@ def disable_hooks(cls): def register_hook( self, module: torch.nn.Module, - func: Callable[[Any], Any], + hook: Callable[[Any], Any], hook_type: str, **kwargs, ): - @wraps(func) + """ + Registers a hook on a specified module with the option to disable it with + HooksMixin.disable_hooks + + :param module: the module on which the hook should be registered + :param hook: the hook to register + :param hook_type: the type of hook to register corresponding to the + `register_{hook_type}_hook` attribute on torch.nn.Module. + Ex. "forward", "forward_pre", "full_backward", "state_dict_post" + :param kwargs: keyword arguments to pass to register hook method + """ + + @wraps(hook) def wrapped_hook(*args, **kwargs): if HooksMixin._HOOKS_DISABLED: return - return func(*args, **kwargs) + return hook(*args, **kwargs) handle = getattr(module, f"register_{hook_type}_hook")(wrapped_hook, **kwargs) self._hooks.append(handle) def remove_hooks(self): - """ - Remove all hooks belonging to a modifier - """ + """Remove all hooks belonging to a modifier""" for hook in self._hooks: hook.remove() From d3058f0363f60961fbccbd00b62806cbb0737e69 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 15 Nov 2024 00:11:24 +0000 Subject: [PATCH 087/285] integrate with smoothquant Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/smoothquant/base.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py index 7487d0609..3c05fea5b 100644 --- a/src/llmcompressor/modifiers/smoothquant/base.py +++ b/src/llmcompressor/modifiers/smoothquant/base.py @@ -11,6 +11,7 @@ get_layer_mappings_from_architecture, handle_mapping_resolution_errors, ) +from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.utils.fsdp.helpers import get_fsdp_parent from llmcompressor.utils.pytorch.module import get_layers, get_matching_layer @@ -52,7 +53,7 @@ class SmoothQuantMapping: balance_layers: List[Module] -class SmoothQuantModifier(Modifier): +class SmoothQuantModifier(Modifier, HooksMixin): """ Implements the SmoothQuant algorithm from https://arxiv.org/abs/2211.10438. This modifier performs a channel-wise smoothing of outliers in activations, making them @@ -99,7 +100,6 @@ class SmoothQuantModifier(Modifier): num_calibration_steps: Optional[int] = None calibration_function: Optional[Callable] = None - hooks_: Optional[List] = None resolved_mappings_: Optional[List] = None scales_: Optional[Dict] = None @@ -127,7 +127,6 @@ def on_initialize(self, state: State, **kwargs) -> bool: self.scales_ = {} calibration_dataloader = state.data.calib - self.hooks_ = [] self._setup_scale_hooks() self._calibrate(state.model, calibration_dataloader) @@ -228,7 +227,7 @@ def hook_fn(module, inp, out): for mapping in self.resolved_mappings_: name = mapping.smooth_name layer = mapping.smooth_layer - self.hooks_.append(layer.register_forward_hook(create_hook_fn(name))) + self.register_hook(layer, create_hook_fn(name), "forward") @torch.no_grad() def _calibrate(self, model: Module, calibration_dataloader: List): @@ -255,9 +254,7 @@ def _calibrate(self, model: Module, calibration_dataloader: List): ) # remove the hooks now that we are done calibrating - for hook in self.hooks_: - hook.remove() - del self.hooks_ + self.remove_hooks() @torch.no_grad() def _apply_smoothing(self, model: Module): From 1ae3ce0ee5002810ae700854d4ef78b56d065c23 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 15 Nov 2024 00:40:16 +0000 Subject: [PATCH 088/285] integrate with QuantizationModifier Signed-off-by: Kyle Sayers --- .../modifiers/quantization/calibration.py | 70 ++++++++----------- .../quantization/quantization/base.py | 43 ++++-------- src/llmcompressor/modifiers/utils/hooks.py | 2 + 3 files changed, 44 insertions(+), 71 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index 0c9508530..ee4ce171e 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Dict, Optional, Tuple import torch from compressed_tensors.quantization import QuantizationStatus, is_attention_module @@ -146,71 +146,57 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str): ) -def calibrate_input_hook(): +def calibrate_input_hook(module: Module, args: Any): """ Hook to calibrate input activations. Will call the observers to update the scales/zp before applying input QDQ in the module's forward pass. """ + args = args[0] if isinstance(args, tuple) else args + calibrate_activations(module, value=args, base_name="input") - def hook_fn(module: Module, inp): - inp = inp[0] if isinstance(inp, tuple) else inp - calibrate_activations(module, value=inp, base_name="input") - return hook_fn - - -def calibrate_output_hook(): +def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor): """ Hook to calibrate output activations. Will call the observers to update the scales/zp before applying output QDQ. """ - - def hook_fn(module: Module, inp, output: torch.Tensor): - calibrate_activations( - module, - value=output, - base_name="output", - ) - output = forward_quantize( - module=module, - value=output, - base_name="output", - args=module.quantization_scheme.output_activations, - ) - return output - - return hook_fn + calibrate_activations( + module, + value=output, + base_name="output", + ) + output = forward_quantize( + module=module, + value=output, + base_name="output", + args=module.quantization_scheme.output_activations, + ) + return output -def calibrate_kv_cache_input_hook(): +def calibrate_kv_cache_input_hook( + module: Module, args: Any, kwargs: Dict[str, Any] +) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: """ Hook to update inputs to attention layers when running kv_cache quantization. Will update the passed in kv_cache to singleton QuantizedKVParameterCache. """ + kv_cache = getattr(module, "kv_cache") + kwargs["past_key_value"] = kv_cache + kwargs["use_cache"] = False + return args, kwargs - def hook_fn(module: Module, args, kwargs): - kv_cache = getattr(module, "kv_cache") - kwargs["past_key_value"] = kv_cache - kwargs["use_cache"] = False - return args, kwargs - - return hook_fn - -def calibrate_kv_cache_output_hook(): +def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor): """ Hook to update k_scale and v_scale parameters when running kv_cache quantization. """ - - def hook_fn(module: Module, inpt, output: torch.Tensor): - kv_cache = getattr(module, "kv_cache") - update_parameter_data(module, kv_cache.k_scales[module.layer_idx], "k_scale") - update_parameter_data(module, kv_cache.v_scales[module.layer_idx], "v_scale") - - return hook_fn + kv_cache = getattr(module, "kv_cache") + update_parameter_data(module, kv_cache.k_scales[module.layer_idx], "k_scale") + update_parameter_data(module, kv_cache.v_scales[module.layer_idx], "v_scale") def set_unset_kv_cache(module: Module): diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py index df2ee1ce1..0f9e337e8 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/base.py +++ b/src/llmcompressor/modifiers/quantization/quantization/base.py @@ -27,6 +27,7 @@ set_unset_kv_cache, update_weight_zp_scale, ) +from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.modifiers.utils.pytorch_helpers import ( is_moe_model, run_calibration_forward, @@ -36,7 +37,7 @@ __all__ = ["QuantizationModifier"] -class QuantizationModifier(Modifier): +class QuantizationModifier(Modifier, HooksMixin): """ Enables post training quantization (PTQ) and quantization aware training (QAT) for a given module or its submodules. After calibration (PTQ) or the start epoch (QAT), @@ -81,7 +82,6 @@ class QuantizationModifier(Modifier): calibration_dataloader_: Any = None calibration_function_: Any = None - calibration_hooks_: List = None def on_initialize(self, state: State, **kwargs) -> bool: if self.end and self.end != -1: @@ -101,7 +101,6 @@ def on_initialize(self, state: State, **kwargs) -> bool: self._check_calibration_data(config) module.apply(update_weight_zp_scale) module.apply(apply_calibration_status) - self.calibration_hooks_ = [] self._calibrate_if_possible(module) self._check_token_distribution( module, threshold=kwargs.get("min_tokens_per_module") @@ -230,15 +229,12 @@ def _calibrate_if_possible(self, module: Module): register_calibration_hooks(): if input activation and not dynamic quant (used to call observers before intput QDQ): - - pre_hook_handle = module.register_forward_pre_hook(calibrate_input_hook()) + - pre_hook := calibrate_input_hook if output activation and not dynamic quant (used to call observers before output QDQ): - - post_hook_handle = module.register_forward_hook(calibrate_kv_cache_output_hook()) + - post_hook := calibrate_kv_cache_output_hook if kv_cache quantization (used to set kv_cache to QuantizedKVParameterCache and update k_scale/v_scale) - - pre_hook_handle = module.register_forward_pre_hook(calibrate_kv_cache_input_hook(), with_kwargs=True) - - post_hook_handle = module.register_forward_hook(calibrate_kv_cache_output_hook()) - - self.calibration_hooks.append(pre_hook_handle) - self.calibration_hooks.append(post_hook_handle) + - pre_hook := calibrate_kv_cache_input_hook + - post_hook := calibrate_kv_cache_output_hook self._calibrate(module) # run forward pass through model using calibration data set_unset_kv_cache() # remove kv_cache objects attached to attention layers @@ -267,8 +263,7 @@ def _calibrate_if_possible(self, module: Module): module.apply(self.register_calibration_hooks) self._calibrate(module) module.apply(set_unset_kv_cache) - for h in self.calibration_hooks_: - h.remove() + self.remove_hooks() def register_calibration_hooks(self, module: Module): """ @@ -278,8 +273,6 @@ def register_calibration_hooks(self, module: Module): if not quantization_scheme: return - pre_hook_handle = None - post_hook_handle = None is_attention_module_ = is_attention_module(module) input_quant = quantization_scheme.input_activations output_quant = quantization_scheme.output_activations @@ -290,27 +283,19 @@ def register_calibration_hooks(self, module: Module): # Calibrate inputs if an input_quant is provided and not running dynamic quant if calibrate_inputs: - pre_hook_handle = module.register_forward_pre_hook(calibrate_input_hook()) + self.register_hook(module, calibrate_input_hook, "forward_pre") if output_quant: # hooks for attn modules if running kv_cache quant if is_attention_module_: - pre_hook_handle = module.register_forward_pre_hook( - calibrate_kv_cache_input_hook(), with_kwargs=True - ) - post_hook_handle = module.register_forward_hook( - calibrate_kv_cache_output_hook() - ) + pre_hook = calibrate_kv_cache_input_hook + self.register_hook(module, pre_hook, "forward_pre", with_kwargs=True) + + self.register_hook(module, calibrate_kv_cache_output_hook, "forward") + # hooks for output quant if not running dynamic quant elif not output_quant.dynamic: - post_hook_handle = module.register_forward_hook(calibrate_output_hook()) - - if pre_hook_handle: - logger.debug(f"Add {pre_hook_handle} for calibration") - self.calibration_hooks_.append(pre_hook_handle) - if post_hook_handle: - logger.debug(f"Add {post_hook_handle} for calibration") - self.calibration_hooks_.append(post_hook_handle) + self.register_hook(module, calibrate_output_hook, "forward") def _calibrate(self, module: Module): class_name = self.__class__.__name__.replace("PyTorch", "") diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index c73cf975f..de4e42898 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -3,6 +3,7 @@ from typing import Any, Callable, ClassVar, List import torch +from loguru import logger from pydantic import BaseModel from torch.utils.hooks import RemovableHandle @@ -67,6 +68,7 @@ def wrapped_hook(*args, **kwargs): handle = getattr(module, f"register_{hook_type}_hook")(wrapped_hook, **kwargs) self._hooks.append(handle) + logger.debug(f"Added {handle} for {self}") def remove_hooks(self): """Remove all hooks belonging to a modifier""" From fc2488f8556ef268651c5c6dac69b3cf26bb3642 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 15 Nov 2024 20:34:51 +0000 Subject: [PATCH 089/285] update hooks in tests Signed-off-by: Kyle Sayers --- tests/llmcompressor/modifiers/calibration/test_kv_cache.py | 4 ++-- tests/llmcompressor/observers/test_helpers.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/llmcompressor/modifiers/calibration/test_kv_cache.py b/tests/llmcompressor/modifiers/calibration/test_kv_cache.py index d9fca8fa2..25b8468f4 100644 --- a/tests/llmcompressor/modifiers/calibration/test_kv_cache.py +++ b/tests/llmcompressor/modifiers/calibration/test_kv_cache.py @@ -54,9 +54,9 @@ def _prep_for_calibration(module: torch.nn.Module): if is_attention_module(module): module.register_forward_pre_hook( - calibrate_kv_cache_input_hook(), with_kwargs=True + calibrate_kv_cache_input_hook, with_kwargs=True ) - module.register_forward_hook(calibrate_kv_cache_output_hook()) + module.register_forward_hook(calibrate_kv_cache_output_hook) module.quantization_status = QuantizationStatus.CALIBRATION diff --git a/tests/llmcompressor/observers/test_helpers.py b/tests/llmcompressor/observers/test_helpers.py index 6668223f7..527176019 100644 --- a/tests/llmcompressor/observers/test_helpers.py +++ b/tests/llmcompressor/observers/test_helpers.py @@ -32,7 +32,7 @@ def _prep_for_input_quant_calibration(module: torch.nn.Module): if not quantization_scheme: return - module.register_forward_pre_hook(calibrate_input_hook()) + module.register_forward_pre_hook(calibrate_input_hook) module.quantization_status = QuantizationStatus.CALIBRATION From d0dc8076dee8912975e4a20e79a36a924440013a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 15 Nov 2024 21:28:50 +0000 Subject: [PATCH 090/285] integrate with wanda Signed-off-by: Kyle Sayers --- .../modifiers/pruning/wanda/base.py | 76 +++++++++---------- src/llmcompressor/modifiers/utils/hooks.py | 6 +- 2 files changed, 42 insertions(+), 40 deletions(-) diff --git a/src/llmcompressor/modifiers/pruning/wanda/base.py b/src/llmcompressor/modifiers/pruning/wanda/base.py index f056ee1ae..4e6784bea 100644 --- a/src/llmcompressor/modifiers/pruning/wanda/base.py +++ b/src/llmcompressor/modifiers/pruning/wanda/base.py @@ -1,3 +1,4 @@ +import functools from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np @@ -9,6 +10,7 @@ from llmcompressor.core import State from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.pruning.wanda.utils.wanda_wrapper import WandaWrapper +from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.utils.pytorch.module import ( @@ -20,7 +22,7 @@ __all__ = ["WandaPruningModifier"] -class WandaPruningModifier(Modifier): +class WandaPruningModifier(Modifier, HooksMixin): """ Modifier for applying the one-shot WANDA algorithm to a model from the paper: https://arxiv.org/abs/2306.11695 @@ -121,7 +123,8 @@ def initialize_compression( "Inferring layer-wise sparsities from " f"{len(dataloader) if dataloader else 0} calibration samples..." ) - self.sparsity = self._infer_layer_sparsity(dataloader) + activations = self._get_activations(dataloader) + self.sparsity = self._infer_layer_sparsity(activations) self._validate_layerwise_sparsity() for idx, (name, layer) in enumerate(self.compressible_layers_.items()): @@ -224,19 +227,17 @@ def _infer_mask_block_size(self): self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":"))) - def _infer_layer_sparsity(self, calibration_dataloader): - acts = _get_activations(self.model, calibration_dataloader) + def _infer_layer_sparsity(self, activations): wanda = {} for name, layer in self.compressible_layers_.items(): prunable_layers = get_prunable_layers(layer) z = [ - m.weight.abs() * acts[f"{name}.{n}"].unsqueeze(0) + m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0) for n, m in prunable_layers.items() ] wanda[name] = torch.cat([item.flatten().cpu() for item in z]) - acts = None - del acts + del activations torch.cuda.empty_cache() outlier_ratios = {} @@ -268,36 +269,35 @@ def _infer_layer_sparsity(self, calibration_dataloader): logger.info(f"Sparsity for {k}: {sparsities[k]}") return sparsities + @torch.no_grad() + def _get_activations(self, data_loader, nsamples=128): + self.model.eval() + acts = {} + + def save_acts(module, input, name): + if isinstance(input, tuple): + input = input[0] + if name not in acts: + acts[name] = ( + 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() + ) + else: + acts[name] += ( + 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() + ) + + for name, mod in self.model.named_modules(): + if isinstance(mod, torch.nn.Linear) and "lm_head" not in name: + self.register_hook( + mod, functools.partial(save_acts, name=name), "forward_pre" + ) + device = next(self.model.parameters()).device + for batch in tqdm(data_loader): + batch = {k: v.to(device) for k, v in batch.items()} + self.model(**batch) + batch = None + torch.cuda.empty_cache() -@torch.no_grad() -def _get_activations(model, data_loader, nsamples=128): - import functools - - model.eval() - acts = {} - - def save_acts(module, input, name): - if isinstance(input, tuple): - input = input[0] - if name not in acts: - acts[name] = 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() - else: - acts[name] += 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() - - hooks = [] - for name, mod in model.named_modules(): - if isinstance(mod, torch.nn.Linear) and "lm_head" not in name: - hooks.append( - mod.register_forward_pre_hook(functools.partial(save_acts, name=name)) - ) - device = next(model.parameters()).device - for batch in tqdm(data_loader): - batch = {k: v.to(device) for k, v in batch.items()} - model(**batch) - batch = None - torch.cuda.empty_cache() - - for h in hooks: - h.remove() + self.remove_hooks() - return acts + return acts diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index de4e42898..39134e273 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -46,7 +46,7 @@ def register_hook( hook: Callable[[Any], Any], hook_type: str, **kwargs, - ): + ) -> RemovableHandle: """ Registers a hook on a specified module with the option to disable it with HooksMixin.disable_hooks @@ -68,7 +68,9 @@ def wrapped_hook(*args, **kwargs): handle = getattr(module, f"register_{hook_type}_hook")(wrapped_hook, **kwargs) self._hooks.append(handle) - logger.debug(f"Added {handle} for {self}") + logger.debug(f"{self} added {handle}") + + return handle def remove_hooks(self): """Remove all hooks belonging to a modifier""" From 55f69d65382626cf06a1356d4adf9633593a3cae Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 15 Nov 2024 21:48:28 +0000 Subject: [PATCH 091/285] integrate with magnitude and constant Signed-off-by: Kyle Sayers --- .../pruning/utils/pytorch/layer_mask.py | 21 +++++-------------- src/llmcompressor/modifiers/utils/hooks.py | 20 +++++++++++------- 2 files changed, 17 insertions(+), 24 deletions(-) diff --git a/src/llmcompressor/modifiers/pruning/utils/pytorch/layer_mask.py b/src/llmcompressor/modifiers/pruning/utils/pytorch/layer_mask.py index 3ada8c7fb..d59b4563b 100644 --- a/src/llmcompressor/modifiers/pruning/utils/pytorch/layer_mask.py +++ b/src/llmcompressor/modifiers/pruning/utils/pytorch/layer_mask.py @@ -2,11 +2,10 @@ from typing import Dict import torch -from pydantic import BaseModel from torch.nn import Parameter -from torch.utils.hooks import RemovableHandle from llmcompressor.core import ModelParameterizedLayer +from llmcompressor.modifiers.utils.hooks import HooksMixin __all__ = ["LayerParamMasking", "param_mask_name"] @@ -39,11 +38,9 @@ class ParameterizedLayerMaskSettings: use_hooks: bool = False -class LayerParamMasking(BaseModel): +class LayerParamMasking(HooksMixin): _mask_settings: Dict[str, ParameterizedLayerMaskSettings] = {} _masked_layer_params: Dict[str, ModelParameterizedLayer] = {} - _forward_hooks: Dict[str, RemovableHandle] = {} - _backward_hooks: Dict[str, RemovableHandle] = {} enabled_: bool = False def add_mask( @@ -100,12 +97,8 @@ def _backward_hook_fn(gradients): return gradients - self._forward_hooks[layer_param_name] = ( - parameterized_layer.layer.register_forward_hook(_forward_hook_fn) - ) - self._backward_hooks[layer_param_name] = ( - parameterized_layer.param.register_hook(_backward_hook_fn) - ) + self.register_hook(parameterized_layer.layer, _forward_hook_fn, "forward") + self.register_hook(parameterized_layer.param, _backward_hook_fn, "") def update_mask( self, @@ -131,11 +124,7 @@ def remove_mask(self, layer_param_name: str): del self._mask_settings[layer_param_name] if mask_settings.use_hooks: - self._forward_hooks[layer_param_name].remove() - self._backward_hooks[layer_param_name].remove() - - del self._forward_hooks[layer_param_name] - del self._backward_hooks[layer_param_name] + self.remove_hooks() def apply_mask_weight(self, layer_param_name: str): if not self.enabled_: diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 39134e273..44e9c13f9 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -1,6 +1,6 @@ import contextlib from functools import wraps -from typing import Any, Callable, ClassVar, List +from typing import Any, Callable, ClassVar, List, Union import torch from loguru import logger @@ -19,7 +19,9 @@ class HooksMixin(BaseModel): Modifiers which implement hooks should register them using `self.register_..._hook(module, hook)` rather than the usual `module.register_..._hook(hook)`. Modifiers should remove hooks with - `self.remove_hooks()` + `self.remove_hooks()`. + + Hooks can be applied to modules or parameters Lifecycle: - modifier.register_forward_hook(module, hook) @@ -42,20 +44,20 @@ def disable_hooks(cls): def register_hook( self, - module: torch.nn.Module, + target: Union[torch.nn.Module, torch.nn.Parameter], hook: Callable[[Any], Any], hook_type: str, **kwargs, ) -> RemovableHandle: """ - Registers a hook on a specified module with the option to disable it with - HooksMixin.disable_hooks + Registers a hook on a specified module/parameter with the option to disable it + with HooksMixin.disable_hooks() - :param module: the module on which the hook should be registered + :param target: the module or parameter on which the hook should be registered :param hook: the hook to register :param hook_type: the type of hook to register corresponding to the `register_{hook_type}_hook` attribute on torch.nn.Module. - Ex. "forward", "forward_pre", "full_backward", "state_dict_post" + Ex. "forward", "forward_pre", "full_backward", "state_dict_post", "" :param kwargs: keyword arguments to pass to register hook method """ @@ -66,7 +68,7 @@ def wrapped_hook(*args, **kwargs): return hook(*args, **kwargs) - handle = getattr(module, f"register_{hook_type}_hook")(wrapped_hook, **kwargs) + handle = getattr(target, f"register_{hook_type}_hook")(wrapped_hook, **kwargs) self._hooks.append(handle) logger.debug(f"{self} added {handle}") @@ -76,3 +78,5 @@ def remove_hooks(self): """Remove all hooks belonging to a modifier""" for hook in self._hooks: hook.remove() + + self._hooks = [] From 59ffe447ac319e5cd5d5c5d0354f22fad7ed55e8 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 15 Nov 2024 21:56:33 +0000 Subject: [PATCH 092/285] integrate with SparseGPTModifier Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/obcq/base.py | 72 +++++++++---------- .../modifiers/pruning/wanda/base.py | 7 +- 2 files changed, 38 insertions(+), 41 deletions(-) diff --git a/src/llmcompressor/modifiers/obcq/base.py b/src/llmcompressor/modifiers/obcq/base.py index 3da0e3d0c..9cf0ff331 100644 --- a/src/llmcompressor/modifiers/obcq/base.py +++ b/src/llmcompressor/modifiers/obcq/base.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np @@ -130,7 +131,8 @@ def initialize_compression( "Inferring layer-wise sparsities from " f"{len(dataloader)} calibration samples..." ) - self.sparsity = self._infer_layer_sparsity(dataloader) + activations = self._get_activations(dataloader) + self.sparsity = self._infer_layer_sparsity(activations) self._validate_layerwise_sparsity() for idx, (name, layer) in enumerate(self.compressible_layers_.items()): @@ -254,19 +256,17 @@ def _infer_mask_block_size(self): self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":"))) - def _infer_layer_sparsity(self, calibration_dataloader): - acts = _get_activations(self.model, calibration_dataloader) + def _infer_layer_sparsity(self, activations): sparsegpt_groups = {} for name, layer in self.compressible_layers_.items(): prunable_layers = get_prunable_layers(layer) z = [ - m.weight.abs() * acts[f"{name}.{n}"].unsqueeze(0) + m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0) for n, m in prunable_layers.items() ] sparsegpt_groups[name] = torch.cat([item.flatten().cpu() for item in z]) - acts = None - del acts + del activations torch.cuda.empty_cache() outlier_ratios = {} @@ -300,36 +300,34 @@ def _infer_layer_sparsity(self, calibration_dataloader): logger.info(f"Sparsity for {k}: {sparsities[k]}") return sparsities + @torch.no_grad() + def _get_activations(self, data_loader, nsamples=128): + self.model.eval() + acts = {} + + def save_acts(module, input, name): + if isinstance(input, tuple): + input = input[0] + if name not in acts: + acts[name] = ( + 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() + ) + else: + acts[name] += ( + 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() + ) + + for name, mod in self.model.named_modules(): + if isinstance(mod, torch.nn.Linear) and "lm_head" not in name: + self.register_hook(mod, partial(save_acts, name=name), "forward_pre") + + device = next(self.model.parameters()).device + for batch in tqdm(data_loader): + batch = {k: v.to(device) for k, v in batch.items()} + self.model(**batch) + batch = None + torch.cuda.empty_cache() -@torch.no_grad() -def _get_activations(model, data_loader, nsamples=128): - import functools - - model.eval() - acts = {} - - def save_acts(module, input, name): - if isinstance(input, tuple): - input = input[0] - if name not in acts: - acts[name] = 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() - else: - acts[name] += 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() - - hooks = [] - for name, mod in model.named_modules(): - if isinstance(mod, torch.nn.Linear) and "lm_head" not in name: - hooks.append( - mod.register_forward_pre_hook(functools.partial(save_acts, name=name)) - ) - device = next(model.parameters()).device - for batch in tqdm(data_loader): - batch = {k: v.to(device) for k, v in batch.items()} - model(**batch) - batch = None - torch.cuda.empty_cache() - - for h in hooks: - h.remove() + self.remove_hooks() - return acts + return acts diff --git a/src/llmcompressor/modifiers/pruning/wanda/base.py b/src/llmcompressor/modifiers/pruning/wanda/base.py index 4e6784bea..0a399db9b 100644 --- a/src/llmcompressor/modifiers/pruning/wanda/base.py +++ b/src/llmcompressor/modifiers/pruning/wanda/base.py @@ -1,4 +1,4 @@ -import functools +from functools import partial from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np @@ -288,9 +288,8 @@ def save_acts(module, input, name): for name, mod in self.model.named_modules(): if isinstance(mod, torch.nn.Linear) and "lm_head" not in name: - self.register_hook( - mod, functools.partial(save_acts, name=name), "forward_pre" - ) + self.register_hook(mod, partial(save_acts, name=name), "forward_pre") + device = next(self.model.parameters()).device for batch in tqdm(data_loader): batch = {k: v.to(device) for k, v in batch.items()} From 21fe61b3e9fa54efab2c0e5ad935b26e1a94b96d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 15 Nov 2024 22:07:20 +0000 Subject: [PATCH 093/285] add hooksmixin to modifier Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/modifier.py | 7 +++---- src/llmcompressor/modifiers/pruning/wanda/base.py | 3 +-- .../modifiers/quantization/quantization/base.py | 3 +-- src/llmcompressor/modifiers/smoothquant/base.py | 3 +-- 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/llmcompressor/modifiers/modifier.py b/src/llmcompressor/modifiers/modifier.py index 494f8bdfc..65b4a4029 100644 --- a/src/llmcompressor/modifiers/modifier.py +++ b/src/llmcompressor/modifiers/modifier.py @@ -1,16 +1,15 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Optional -from pydantic import BaseModel - from llmcompressor.core.events import Event, EventType from llmcompressor.core.state import State from llmcompressor.modifiers.interface import ModifierInterface +from llmcompressor.modifiers.utils.hooks import HooksMixin __all__ = ["Modifier"] -class Modifier(BaseModel, ModifierInterface, ABC): +class Modifier(ModifierInterface, HooksMixin): """ A base class for all modifiers to inherit from. Modifiers are used to modify the training process for a model. diff --git a/src/llmcompressor/modifiers/pruning/wanda/base.py b/src/llmcompressor/modifiers/pruning/wanda/base.py index 0a399db9b..1881a347c 100644 --- a/src/llmcompressor/modifiers/pruning/wanda/base.py +++ b/src/llmcompressor/modifiers/pruning/wanda/base.py @@ -10,7 +10,6 @@ from llmcompressor.core import State from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.pruning.wanda.utils.wanda_wrapper import WandaWrapper -from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.utils.pytorch.module import ( @@ -22,7 +21,7 @@ __all__ = ["WandaPruningModifier"] -class WandaPruningModifier(Modifier, HooksMixin): +class WandaPruningModifier(Modifier): """ Modifier for applying the one-shot WANDA algorithm to a model from the paper: https://arxiv.org/abs/2306.11695 diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py index 0f9e337e8..67da67bc0 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/base.py +++ b/src/llmcompressor/modifiers/quantization/quantization/base.py @@ -27,7 +27,6 @@ set_unset_kv_cache, update_weight_zp_scale, ) -from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.modifiers.utils.pytorch_helpers import ( is_moe_model, run_calibration_forward, @@ -37,7 +36,7 @@ __all__ = ["QuantizationModifier"] -class QuantizationModifier(Modifier, HooksMixin): +class QuantizationModifier(Modifier): """ Enables post training quantization (PTQ) and quantization aware training (QAT) for a given module or its submodules. After calibration (PTQ) or the start epoch (QAT), diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py index 3c05fea5b..f4117e31d 100644 --- a/src/llmcompressor/modifiers/smoothquant/base.py +++ b/src/llmcompressor/modifiers/smoothquant/base.py @@ -11,7 +11,6 @@ get_layer_mappings_from_architecture, handle_mapping_resolution_errors, ) -from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.utils.fsdp.helpers import get_fsdp_parent from llmcompressor.utils.pytorch.module import get_layers, get_matching_layer @@ -53,7 +52,7 @@ class SmoothQuantMapping: balance_layers: List[Module] -class SmoothQuantModifier(Modifier, HooksMixin): +class SmoothQuantModifier(Modifier): """ Implements the SmoothQuant algorithm from https://arxiv.org/abs/2211.10438. This modifier performs a channel-wise smoothing of outliers in activations, making them From a5635a1032a059c4a70252eea9f1a36a7ea8aa4f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 19 Nov 2024 12:43:37 -0500 Subject: [PATCH 094/285] merge Signed-off-by: Kyle Sayers --- examples/quantization_w4a16/vision_example.py | 15 +++------------ shubhra.py | 2 +- .../modifiers/quantization/gptq/base.py | 14 ++++---------- 3 files changed, 8 insertions(+), 23 deletions(-) diff --git a/examples/quantization_w4a16/vision_example.py b/examples/quantization_w4a16/vision_example.py index f89ada21a..88fc79983 100644 --- a/examples/quantization_w4a16/vision_example.py +++ b/examples/quantization_w4a16/vision_example.py @@ -1,5 +1,5 @@ from datasets import load_dataset -from transformers import AutoProcessor +from transformers import AutoProcessor, MllamaForConditionalGeneration from llmcompressor.modifiers.quantization import GPTQModifier from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot @@ -7,12 +7,11 @@ # Select model and load it. MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" -model = SparseAutoModelForCausalLM.from_pretrained( +model = MllamaForConditionalGeneration.from_pretrained( MODEL_ID, device_map="cuda:0", torch_dtype="auto", ) -breakpoint() processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True) # Select calibration dataset. @@ -61,7 +60,7 @@ def tokenize(sample): # Configure the quantization algorithm to run. # * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"], batch_size=-1, dampening_frac=0.5) +recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"], update_size=NUM_CALIBRATION_SAMPLES, dampening_frac=0.5) # Apply algorithms. oneshot( @@ -74,14 +73,6 @@ def tokenize(sample): trust_remote_code_model=True, ) -# Confirm generations of the quantized model look sane. -print("\n\n") -print("========== SAMPLE GENERATION ==============") -input_ids = processor("Hello my name is", return_tensors="pt").input_ids.to("cuda") -output = model.generate(input_ids, max_new_tokens=100) -print(processor.decode(output[0])) -print("==========================================\n\n") - # Save to disk compressed. SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128" model.save_pretrained(SAVE_DIR, save_compressed=True) diff --git a/shubhra.py b/shubhra.py index 4996c8277..6cf911d8b 100644 --- a/shubhra.py +++ b/shubhra.py @@ -16,7 +16,7 @@ DATASET_ID = "lmms-lab/flickr30k" DATASET_SPLIT = "test[:128]" -NUM_CALIBRATION_SAMPLES = 1#128 +NUM_CALIBRATION_SAMPLES = 128 MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index ab2daa1a3..a7d67b232 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch -import math from compressed_tensors.quantization import ( QuantizationScheme, ) @@ -16,21 +15,17 @@ from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import accumulate_hessian, make_empty_hessian, quantize_weight from llmcompressor.modifiers.quantization.gptq.utils.partitioned_model import PartitionedModel from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier -from llmcompressor.modifiers.utils.hooks import HooksMixin, LayerCompressorMixin -from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException, run_calibration_forward -from llmcompressor.observers.base import Observer +from llmcompressor.modifiers.utils.hooks import HooksMixin +from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException from llmcompressor.transformers.finetune.data.data_helpers import ( create_batch_dataloader, ) -from llmcompressor.utils.fsdp.helpers import delete_offload_parameter, register_offload_parameter, update_offload_parameter +from llmcompressor.utils.fsdp.helpers import update_offload_parameter from llmcompressor.utils.helpers import ( align_module, calibration_forward_context, getattr_chain, ) -from compressed_tensors.quantization import ( - fake_quantize, -) from llmcompressor.utils.metric_logging import CompressionLogger from llmcompressor.utils.pytorch.module import get_no_split_params, qat_active @@ -213,12 +208,11 @@ def hook(module: torch.nn.Module, args: Tuple[Any, ...]): raise EarlyStopException(None, None) # feed data - dataloader = create_batch_dataloader(dataloader, batch_size=1) with calibration_forward_context(state.model): targets = get_no_split_params(state.model) partitioned_model = PartitionedModel() partitioned_model.init_forward(state.model, targets) - partitioned_model.forward_data(dataloader, mask_padding=True) + partitioned_model.forward_data(state.data.calib, mask_padding=True) return True From 83ed409fb1de2762a3784e12916a029f925af207 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 19 Nov 2024 13:50:35 -0500 Subject: [PATCH 095/285] small updates Signed-off-by: Kyle Sayers --- shubhra.py | 5 ++--- src/llmcompressor/modifiers/quantization/gptq/base.py | 5 ++++- .../modifiers/quantization/gptq/utils/partitioned_model.py | 1 + src/llmcompressor/modifiers/utils/hooks.py | 6 +----- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/shubhra.py b/shubhra.py index 6cf911d8b..4bb5409a5 100644 --- a/shubhra.py +++ b/shubhra.py @@ -2,14 +2,13 @@ from transformers import AutoProcessor, MllamaForConditionalGeneration, LlavaForConditionalGeneration from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.transformers import oneshot, wrap_hf_model_class +from llmcompressor.transformers import oneshot import os # Load model. #model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" model_id = "mgoin/pixtral-12b" -model_class = wrap_hf_model_class(LlavaForConditionalGeneration) -model = model_class.from_pretrained(model_id, device_map="auto", torch_dtype="auto", trust_remote_code=True, _attn_implementation="eager",) +model = LlavaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto", _attn_implementation="eager",) processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) print("Loading dataset") diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index a7d67b232..44a43daf2 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -201,11 +201,13 @@ def on_initialize(self, state: "State", **kwargs) -> bool: for name, module in state.model.named_modules(): if getattr_chain(module, "quantization_scheme.weights", None) is not None: post_hook = partial(self.compress_module, name) - self.register_forward_hook(module, post_hook) + self.register_hook(module, post_hook, "forward") if "head" in name: def hook(module: torch.nn.Module, args: Tuple[Any, ...]): raise EarlyStopException(None, None) + + self.register_hook(module, hook, "forward_pre") # feed data with calibration_forward_context(state.model): @@ -235,6 +237,7 @@ def compress_module( name: str, module: torch.nn.Module, args: Tuple[torch.Tensor, ...], + _output: torch.Tensor, ): """ Quantize a module's weight according to the GPTQ algorithm diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py index 549a5f534..b55c1e223 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py @@ -180,6 +180,7 @@ def is_leaf_module(self, module: torch.nn.Module, module_qualified_name: str) -> return super().is_leaf_module(module, module_qualified_name) self.graph: GraphModule = symbolic_trace(model, disable_check=True, tracer_cls=CustomTracer) + #self.graph: GraphModule = CustomTracer().trace(model, dummy_inputs=model.dummy_inputs) # 2. identify target nodes all_target_nodes = get_target_nodes(self.graph, targets) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 44e9c13f9..721a26129 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -15,14 +15,11 @@ class HooksMixin(BaseModel): Mixin to manage hook registration, disabling, and removal. Modifiers should use `self.register_hook(module, hook, hook_type)` for hook registration and `self.remove_hooks()` for removal. - Modifiers which implement hooks should register them using `self.register_..._hook(module, hook)` rather than the usual `module.register_..._hook(hook)`. Modifiers should remove hooks with `self.remove_hooks()`. - Hooks can be applied to modules or parameters - Lifecycle: - modifier.register_forward_hook(module, hook) - with HooksMixin.disable_hooks(): model.forward() @@ -52,7 +49,6 @@ def register_hook( """ Registers a hook on a specified module/parameter with the option to disable it with HooksMixin.disable_hooks() - :param target: the module or parameter on which the hook should be registered :param hook: the hook to register :param hook_type: the type of hook to register corresponding to the @@ -79,4 +75,4 @@ def remove_hooks(self): for hook in self._hooks: hook.remove() - self._hooks = [] + self._hooks = [] \ No newline at end of file From d1042826707e8fb682d0c2ba298c71aec8c7bc6a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 20 Nov 2024 18:10:25 -0500 Subject: [PATCH 096/285] WIP --- .../gptq/utils/partitioned_model.py | 103 +++++++++++++++++- .../quantization/quantization/base.py | 2 +- src/llmcompressor/modifiers/utils/hooks.py | 2 + src/llmcompressor/pytorch/__init__.py | 72 ++++++------ 4 files changed, 136 insertions(+), 43 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py index b55c1e223..27c4e6cda 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py @@ -1,7 +1,9 @@ import contextlib -from typing import Any, Callable, Dict, List, Set +import inspect +from typing import Any, Callable, Dict, List, Set, Tuple +import tqdm import torch from collections import deque from transformers import AutoModel @@ -12,6 +14,7 @@ from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException, apply_pad_mask_to_batch from llmcompressor.pytorch.utils.helpers import tensors_to_device +from llmcompressor.utils.helpers import calibration_forward_context def get_target_nodes(graph: GraphModule, targets: List[str]): @@ -144,6 +147,11 @@ def partition_graph(model: torch.nn.Module, partitions: List[List[Node]]): print([n for n in subgraph.nodes]) assert check_assumption(subgraph) + return subgraphs + + +def trace_consumed_names(subgraphs: List[Dict[str, Any]]): + # TODO: update consumed names as new partitions are appended # populate consumed_names according to when inputs are last used # in order to vacate the `intermediates` cache and save memory all_input_names = set().union(*(subgraph["input_names"] for subgraph in subgraphs)) @@ -155,22 +163,75 @@ def partition_graph(model: torch.nn.Module, partitions: List[List[Node]]): else: assert False - return subgraphs + +def make_fused_concrete_args(root: torch.nn.Module, dummy_inputs: Dict[str, Any]): + sig = inspect.signature(root.forward if isinstance(root, torch.nn.Module) else root) + + concrete_args = {} + + for param in sig.parameters.values(): + if param.name in dummy_inputs: + continue + if param.default is inspect.Parameter.empty: + raise ValueError(f"You need to specify a default value for the parameter {param.name}.") + + concrete_args.update( + { + p.name: p.default + for p in sig.parameters.values() + if (p.name not in dummy_inputs and p.name not in concrete_args) + } + ) + concrete_args.update(dummy_inputs) + + return concrete_args + + +def make_placeholders(model: torch.nn.Module, graph: GraphModule, dummy_inputs: Dict[str, Any]): + reversed_dictionary = {v: k for k, v in dummy_inputs.items()} + # TODO: this dictionary does not match tensors which have been deep copied + # in general it's pretty annoying, since tracer.create_args_for_root basically + # converts kwargs to args and therefore gets rid of any of the names. + + # maybe instead of caching by kwargs, we cache by arg tuples? Not sure + + # Note, maybe relevant: tracer.create_args_for_root converts kwargs to args using the forward function signature + + breakpoint() + for node in graph.graph.nodes: + if node.op == "get_attr" and "tensor_constant" in node.target: + name = reversed_dictionary[getattr(model, node.target)] + node.target = name + node.op = "placeholder" + + breakpoint() + class PartitionedModel: def __init__(self): self.hook_targets = [] - self.hook_target_nodes = [] self.graph = None self.subgraphs = [] self.model = None + def partition_graph(self, graph: GraphModule, inputs: Tuple[Any, ...]): + print("partition_graph") + breakpoint() + + partitions = topological_partition(graph, self.targets) + subgraphs = partition_graph(self.model, partitions) + self.subgraphs.extend(subgraphs) + #breakpoint() + + return graph.forward + def register_hook(self, func: Callable, targets: List[str]): self.hook_targets.append((func, targets)) def init_forward(self, model: torch.nn.Module, targets): self.model = model + self.targets = targets # 1. trace graph class CustomTracer(HFTracer): @@ -178,9 +239,36 @@ def is_leaf_module(self, module: torch.nn.Module, module_qualified_name: str) -> if type(module).__name__ in targets: return True # Treat as leaf, skip tracing inside this module return super().is_leaf_module(module, module_qualified_name) + + def _proxy_placeholder(self, *args, **kwargs): + node = super()._proxy_placeholder(*args, **kwargs) + return node + + # create a dictionary which maps the arg name to their nodes so we can more + # easily create a dictionary from names to nodes so we can identify which ones should be lifted into placeholders + - self.graph: GraphModule = symbolic_trace(model, disable_check=True, tracer_cls=CustomTracer) - #self.graph: GraphModule = CustomTracer().trace(model, dummy_inputs=model.dummy_inputs) + with HooksMixin.disable_hooks(), calibration_forward_context(model): + #compiled = torch.compile(model, backend=self.partition_graph) + #compiled(**model.dummy_inputs) + + #program = torch.export.export(model, tuple(), model.dummy_inputs, strict=False) + #program = torch.export.export(model, tuple(), {}, strict=False) # requires inputs + + #self.graph: GraphModule = symbolic_trace(model, disable_check=True, tracer_cls=CustomTracer) + #self.graph: GraphModule = CustomTracer().trace(model, dummy_inputs=model.dummy_inputs) + + concrete_args = make_fused_concrete_args(model, model.dummy_inputs) + print(concrete_args) + tracer = CustomTracer() + graph: GraphModule = tracer.trace(model, concrete_args=concrete_args, complete_concrete_args_with_inputs_not_in_dummy_inputs=False) + self.graph = torch.fx.GraphModule(model, graph) + self.graph.config = model.config + self.graph.class_for_deserialization = model.__class__ + self.graph.device = model.device + make_placeholders(model, self.graph, model.dummy_inputs) + breakpoint() + # 2. identify target nodes all_target_nodes = get_target_nodes(self.graph, targets) @@ -189,11 +277,13 @@ def is_leaf_module(self, module: torch.nn.Module, module_qualified_name: str) -> partitions: List[List[Node]] = topological_partition(self.graph, all_target_nodes) self.subgraphs: List[GraphModule] = partition_graph(model, partitions) + trace_consumed_names(self.subgraphs) + def forward_data( self, dataloader, mask_padding: bool = True, - run_twice: bool = True + run_twice: bool = False ): # TODO: give option to skip lm_head # 4. perform compression @@ -225,6 +315,7 @@ def forward_data( inputs = {input_name: intermediates[input_name] for input_name in subgraph["input_names"]} inputs = tensors_to_device(inputs, model_device) + print(inputs) try: subgraph_output = forward_function(self.model, **inputs) except EarlyStopException: diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py index 67da67bc0..cf339374d 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/base.py +++ b/src/llmcompressor/modifiers/quantization/quantization/base.py @@ -207,7 +207,7 @@ def _check_calibration_data(self, config: QuantizationConfig): def _apply_modifier_to_model(self, model: Module): modifier_as_config = self.create_init_config() # Add step to attach kv_cache to the model, if present within the config - apply_quantization_config(model, modifier_as_config) + #apply_quantization_config(model, modifier_as_config) model.apply(set_unset_kv_cache) return modifier_as_config diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 721a26129..76eabf35d 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -57,6 +57,8 @@ def register_hook( :param kwargs: keyword arguments to pass to register hook method """ + return None + @wraps(hook) def wrapped_hook(*args, **kwargs): if HooksMixin._HOOKS_DISABLED: diff --git a/src/llmcompressor/pytorch/__init__.py b/src/llmcompressor/pytorch/__init__.py index 66d4be1b4..869f83f04 100644 --- a/src/llmcompressor/pytorch/__init__.py +++ b/src/llmcompressor/pytorch/__init__.py @@ -7,39 +7,39 @@ from packaging import version -try: - import torch - - _PARSED_TORCH_VERSION = version.parse(torch.__version__) - - if _PARSED_TORCH_VERSION.major >= 2: - torch_compile_func = torch.compile - - def raise_torch_compile_warning(*args, **kwargs): - warnings.warn( - "torch.compile is not supported by llmcompressor for torch 2.0.x" - ) - return torch_compile_func(*args, **kwargs) - - torch.compile = raise_torch_compile_warning - - _BYPASS = bool(int(os.environ.get("NM_BYPASS_TORCH_VERSION", "0"))) - if _PARSED_TORCH_VERSION.major == 1 and _PARSED_TORCH_VERSION.minor in [10, 11]: - if not _BYPASS: - raise RuntimeError( - "llmcompressor does not support torch==1.10.* or 1.11.*. " - f"Found torch version {torch.__version__}.\n\n" - "To bypass this error, set environment variable " - "`NM_BYPASS_TORCH_VERSION` to '1'.\n\n" - "Bypassing may result in errors or " - "incorrect behavior, so set at your own risk." - ) - else: - warnings.warn( - "llmcompressor quantized onnx export does not work " - "with torch==1.10.* or 1.11.*" - ) -except ImportError: - pass - -# flake8: noqa +# try: +# import torch + +# _PARSED_TORCH_VERSION = version.parse(torch.__version__) + +# if _PARSED_TORCH_VERSION.major >= 2: +# torch_compile_func = torch.compile + +# def raise_torch_compile_warning(*args, **kwargs): +# warnings.warn( +# "torch.compile is not supported by llmcompressor for torch 2.0.x" +# ) +# return torch_compile_func(*args, **kwargs) + +# torch.compile = raise_torch_compile_warning + +# _BYPASS = bool(int(os.environ.get("NM_BYPASS_TORCH_VERSION", "0"))) +# if _PARSED_TORCH_VERSION.major == 1 and _PARSED_TORCH_VERSION.minor in [10, 11]: +# if not _BYPASS: +# raise RuntimeError( +# "llmcompressor does not support torch==1.10.* or 1.11.*. " +# f"Found torch version {torch.__version__}.\n\n" +# "To bypass this error, set environment variable " +# "`NM_BYPASS_TORCH_VERSION` to '1'.\n\n" +# "Bypassing may result in errors or " +# "incorrect behavior, so set at your own risk." +# ) +# else: +# warnings.warn( +# "llmcompressor quantized onnx export does not work " +# "with torch==1.10.* or 1.11.*" +# ) +# except ImportError: +# pass + +# # flake8: noqa From 236a47a4aadb663ba3534bf0f7465269b8423352 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 21 Nov 2024 17:30:10 -0500 Subject: [PATCH 097/285] WIP --- shubhra.py | 29 +++++- .../gptq/utils/partitioned_model.py | 93 ++++++++++++------- .../compressed_tensors_utils.py | 1 + 3 files changed, 87 insertions(+), 36 deletions(-) diff --git a/shubhra.py b/shubhra.py index 4bb5409a5..57ea9f644 100644 --- a/shubhra.py +++ b/shubhra.py @@ -1,5 +1,6 @@ +import torch from datasets import load_dataset -from transformers import AutoProcessor, MllamaForConditionalGeneration, LlavaForConditionalGeneration +from transformers import AutoProcessor, MllamaForConditionalGeneration, LlavaForConditionalGeneration, AutoModel from llmcompressor.modifiers.quantization import GPTQModifier from llmcompressor.transformers import oneshot @@ -7,15 +8,17 @@ # Load model. #model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" +#model = MllamaForConditionalGeneration.from_pretrained(model_id) model_id = "mgoin/pixtral-12b" -model = LlavaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto", _attn_implementation="eager",) +model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", _attn_implementation="eager") +#model = AutoModel.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto", torch_dtype="auto") processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) print("Loading dataset") DATASET_ID = "lmms-lab/flickr30k" DATASET_SPLIT = "test[:128]" -NUM_CALIBRATION_SAMPLES = 128 +NUM_CALIBRATION_SAMPLES = 1 MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. @@ -47,11 +50,27 @@ def preprocess(example): # Tokenize inputs. def tokenize(sample): - return processor(sample["image"], sample["text"], add_special_tokens=False, return_tensors="pt") + tmp = processor( + sample["image"], + sample["text"], + add_special_tokens=False, + return_tensors="pt" + ) + + # Remove batch dimension from each key + input_ids = tmp["input_ids"].squeeze(0) + #attention_mask = tmp["attention_mask"].squeeze(0) + #pixel_values = [tmp["pixel_values"][0][0].squeeze(0)] + + return { + "input_ids": torch.LongTensor(input_ids), + #"attention_mask": attention_mask, + #"pixel_values": pixel_values, + } + ds = ds.map(tokenize, remove_columns=ds.column_names) -print(ds) print("Setting up quantization params") # Configure the quantization algorithm and scheme. diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py index 27c4e6cda..1ccc57c13 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py @@ -187,8 +187,7 @@ def make_fused_concrete_args(root: torch.nn.Module, dummy_inputs: Dict[str, Any] return concrete_args -def make_placeholders(model: torch.nn.Module, graph: GraphModule, dummy_inputs: Dict[str, Any]): - reversed_dictionary = {v: k for k, v in dummy_inputs.items()} +def make_placeholders(tracer, model: torch.nn.Module, graph: GraphModule, dummy_inputs: Dict[str, Any]): # TODO: this dictionary does not match tensors which have been deep copied # in general it's pretty annoying, since tracer.create_args_for_root basically # converts kwargs to args and therefore gets rid of any of the names. @@ -197,14 +196,21 @@ def make_placeholders(model: torch.nn.Module, graph: GraphModule, dummy_inputs: # Note, maybe relevant: tracer.create_args_for_root converts kwargs to args using the forward function signature + # TODO: assumes that all inputs are tensors breakpoint() - for node in graph.graph.nodes: - if node.op == "get_attr" and "tensor_constant" in node.target: - name = reversed_dictionary[getattr(model, node.target)] - node.target = name - node.op = "placeholder" + for input_name, input_value in dummy_inputs.items(): + for tensor_value, name in tracer.tensor_attrs.items(): + if torch.allclose(input_value, tensor_value): + nodes = graph.graph.find_nodes(op="get_attr", target=name) + assert len(nodes) == 1 + node = nodes[0] + node.target = input_name + node.name = input_name + node.op = "placeholder" + break - breakpoint() + else: + raise ValueError() @@ -233,22 +239,26 @@ def init_forward(self, model: torch.nn.Module, targets): self.model = model self.targets = targets + def forward_data( + self, + dataloader, + mask_padding: bool = True, + run_twice: bool = False + ): + + #from pixtral_code import forward as compiled_forward + #compiled_forward(self.model, ) + # 1. trace graph + targets = self.targets class CustomTracer(HFTracer): def is_leaf_module(self, module: torch.nn.Module, module_qualified_name: str) -> bool: if type(module).__name__ in targets: return True # Treat as leaf, skip tracing inside this module return super().is_leaf_module(module, module_qualified_name) - - def _proxy_placeholder(self, *args, **kwargs): - node = super()._proxy_placeholder(*args, **kwargs) - return node - - # create a dictionary which maps the arg name to their nodes so we can more - # easily create a dictionary from names to nodes so we can identify which ones should be lifted into placeholders - with HooksMixin.disable_hooks(), calibration_forward_context(model): + with HooksMixin.disable_hooks(), calibration_forward_context(self.model): #compiled = torch.compile(model, backend=self.partition_graph) #compiled(**model.dummy_inputs) @@ -258,33 +268,53 @@ def _proxy_placeholder(self, *args, **kwargs): #self.graph: GraphModule = symbolic_trace(model, disable_check=True, tracer_cls=CustomTracer) #self.graph: GraphModule = CustomTracer().trace(model, dummy_inputs=model.dummy_inputs) - concrete_args = make_fused_concrete_args(model, model.dummy_inputs) + #sample_input = next(iter(dataloader)) + sample_input = self.model.dummy_inputs + breakpoint() + + concrete_args = make_fused_concrete_args(self.model, sample_input) print(concrete_args) tracer = CustomTracer() - graph: GraphModule = tracer.trace(model, concrete_args=concrete_args, complete_concrete_args_with_inputs_not_in_dummy_inputs=False) - self.graph = torch.fx.GraphModule(model, graph) - self.graph.config = model.config - self.graph.class_for_deserialization = model.__class__ - self.graph.device = model.device - make_placeholders(model, self.graph, model.dummy_inputs) + graph: GraphModule = tracer.trace(self.model, concrete_args=concrete_args, complete_concrete_args_with_inputs_not_in_dummy_inputs=False) + self.graph = torch.fx.GraphModule(self.model, graph) + self.graph.config = self.model.config + self.graph.class_for_deserialization = self.model.__class__ + self.graph.device = self.model.device + breakpoint() + make_placeholders(tracer, self.model, self.graph, sample_input) breakpoint() # 2. identify target nodes - all_target_nodes = get_target_nodes(self.graph, targets) + all_target_nodes = get_target_nodes(self.graph, self.targets) # 3. cut into partitions along target nodes partitions: List[List[Node]] = topological_partition(self.graph, all_target_nodes) - self.subgraphs: List[GraphModule] = partition_graph(model, partitions) + self.subgraphs: List[GraphModule] = partition_graph(self.model, partitions) trace_consumed_names(self.subgraphs) - def forward_data( - self, - dataloader, - mask_padding: bool = True, - run_twice: bool = False - ): + + + + + + + + + + + + + + + + + + + + + # TODO: give option to skip lm_head # 4. perform compression model_device = next(self.model.parameters()).device @@ -316,6 +346,7 @@ def forward_data( inputs = {input_name: intermediates[input_name] for input_name in subgraph["input_names"]} inputs = tensors_to_device(inputs, model_device) print(inputs) + breakpoint() try: subgraph_output = forward_function(self.model, **inputs) except EarlyStopException: diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index 001b43b0b..e6f10c80e 100644 --- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -226,6 +226,7 @@ def patch_tied_tensors_bug(model: torch.nn.Module): :param model: model to fix """ + return if ( hasattr(model.config, "tie_word_embeddings") and not model.config.tie_word_embeddings From 188896e2ba6f9a1d5667e5fa0bc0b79af934df8a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 21 Nov 2024 18:00:32 -0500 Subject: [PATCH 098/285] able to run without hooks --- shubhra.py | 10 +++---- .../gptq/utils/partitioned_model.py | 28 ++----------------- 2 files changed, 7 insertions(+), 31 deletions(-) diff --git a/shubhra.py b/shubhra.py index 57ea9f644..4f549703c 100644 --- a/shubhra.py +++ b/shubhra.py @@ -10,15 +10,15 @@ #model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" #model = MllamaForConditionalGeneration.from_pretrained(model_id) model_id = "mgoin/pixtral-12b" -model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", _attn_implementation="eager") +model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="cuda:0", _attn_implementation="eager") #model = AutoModel.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto", torch_dtype="auto") processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) print("Loading dataset") DATASET_ID = "lmms-lab/flickr30k" -DATASET_SPLIT = "test[:128]" +DATASET_SPLIT = "test[:512]" -NUM_CALIBRATION_SAMPLES = 1 +NUM_CALIBRATION_SAMPLES = 512 MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. @@ -59,12 +59,12 @@ def tokenize(sample): # Remove batch dimension from each key input_ids = tmp["input_ids"].squeeze(0) - #attention_mask = tmp["attention_mask"].squeeze(0) + attention_mask = tmp["attention_mask"].squeeze(0) #pixel_values = [tmp["pixel_values"][0][0].squeeze(0)] return { "input_ids": torch.LongTensor(input_ids), - #"attention_mask": attention_mask, + "attention_mask": attention_mask, #"pixel_values": pixel_values, } diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py index 1ccc57c13..3000c530b 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py @@ -10,6 +10,7 @@ from torch.fx import GraphModule, Graph, Node from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.utils.fx import symbolic_trace, HFTracer +from accelerate.hooks import remove_hook_from_module from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException, apply_pad_mask_to_batch @@ -197,7 +198,6 @@ def make_placeholders(tracer, model: torch.nn.Module, graph: GraphModule, dummy_ # Note, maybe relevant: tracer.create_args_for_root converts kwargs to args using the forward function signature # TODO: assumes that all inputs are tensors - breakpoint() for input_name, input_value in dummy_inputs.items(): for tensor_value, name in tracer.tensor_attrs.items(): if torch.allclose(input_value, tensor_value): @@ -223,12 +223,10 @@ def __init__(self): def partition_graph(self, graph: GraphModule, inputs: Tuple[Any, ...]): print("partition_graph") - breakpoint() partitions = topological_partition(graph, self.targets) subgraphs = partition_graph(self.model, partitions) self.subgraphs.extend(subgraphs) - #breakpoint() return graph.forward @@ -270,19 +268,17 @@ def is_leaf_module(self, module: torch.nn.Module, module_qualified_name: str) -> #sample_input = next(iter(dataloader)) sample_input = self.model.dummy_inputs - breakpoint() concrete_args = make_fused_concrete_args(self.model, sample_input) print(concrete_args) tracer = CustomTracer() + remove_hook_from_module(self.model, recurse=True) graph: GraphModule = tracer.trace(self.model, concrete_args=concrete_args, complete_concrete_args_with_inputs_not_in_dummy_inputs=False) self.graph = torch.fx.GraphModule(self.model, graph) self.graph.config = self.model.config self.graph.class_for_deserialization = self.model.__class__ self.graph.device = self.model.device - breakpoint() make_placeholders(tracer, self.model, self.graph, sample_input) - breakpoint() # 2. identify target nodes @@ -294,24 +290,6 @@ def is_leaf_module(self, module: torch.nn.Module, module_qualified_name: str) -> trace_consumed_names(self.subgraphs) - - - - - - - - - - - - - - - - - - @@ -345,8 +323,6 @@ def is_leaf_module(self, module: torch.nn.Module, module_qualified_name: str) -> inputs = {input_name: intermediates[input_name] for input_name in subgraph["input_names"]} inputs = tensors_to_device(inputs, model_device) - print(inputs) - breakpoint() try: subgraph_output = forward_function(self.model, **inputs) except EarlyStopException: From 8ef9c23c2b064f517a87da2208073f922fcf45f6 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 21 Nov 2024 18:26:06 -0500 Subject: [PATCH 099/285] issue with different sizes --- .../modifiers/quantization/gptq/base.py | 7 ++-- .../gptq/utils/partitioned_model.py | 32 ++++++------------- src/llmcompressor/modifiers/utils/hooks.py | 3 -- 3 files changed, 14 insertions(+), 28 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 44a43daf2..a57bfd939 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -196,6 +196,10 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self._quantization_modifier.initialize(state, **kwargs) if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") + + targets = get_no_split_params(state.model) + partitioned_model = PartitionedModel() + partitioned_model.init_forward(state.model, targets, next(iter(state.data.calib))) # register hooks for name, module in state.model.named_modules(): @@ -211,9 +215,6 @@ def hook(module: torch.nn.Module, args: Tuple[Any, ...]): # feed data with calibration_forward_context(state.model): - targets = get_no_split_params(state.model) - partitioned_model = PartitionedModel() - partitioned_model.init_forward(state.model, targets) partitioned_model.forward_data(state.data.calib, mask_padding=True) return True diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py index 3000c530b..9b8b84459 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py @@ -216,7 +216,6 @@ def make_placeholders(tracer, model: torch.nn.Module, graph: GraphModule, dummy_ class PartitionedModel: def __init__(self): - self.hook_targets = [] self.graph = None self.subgraphs = [] self.model = None @@ -230,23 +229,10 @@ def partition_graph(self, graph: GraphModule, inputs: Tuple[Any, ...]): return graph.forward - def register_hook(self, func: Callable, targets: List[str]): - self.hook_targets.append((func, targets)) - - def init_forward(self, model: torch.nn.Module, targets): + def init_forward(self, model: torch.nn.Module, targets: List[str], dummy_input: Dict[str, Any]): self.model = model self.targets = targets - def forward_data( - self, - dataloader, - mask_padding: bool = True, - run_twice: bool = False - ): - - #from pixtral_code import forward as compiled_forward - #compiled_forward(self.model, ) - # 1. trace graph targets = self.targets class CustomTracer(HFTracer): @@ -267,9 +253,7 @@ def is_leaf_module(self, module: torch.nn.Module, module_qualified_name: str) -> #self.graph: GraphModule = CustomTracer().trace(model, dummy_inputs=model.dummy_inputs) #sample_input = next(iter(dataloader)) - sample_input = self.model.dummy_inputs - - concrete_args = make_fused_concrete_args(self.model, sample_input) + concrete_args = make_fused_concrete_args(self.model, dummy_input) print(concrete_args) tracer = CustomTracer() remove_hook_from_module(self.model, recurse=True) @@ -278,7 +262,7 @@ def is_leaf_module(self, module: torch.nn.Module, module_qualified_name: str) -> self.graph.config = self.model.config self.graph.class_for_deserialization = self.model.__class__ self.graph.device = self.model.device - make_placeholders(tracer, self.model, self.graph, sample_input) + make_placeholders(tracer, self.model, self.graph, dummy_input) # 2. identify target nodes @@ -290,9 +274,12 @@ def is_leaf_module(self, module: torch.nn.Module, module_qualified_name: str) -> trace_consumed_names(self.subgraphs) - - - + def forward_data( + self, + dataloader, + mask_padding: bool = True, + run_twice: bool = False + ): # TODO: give option to skip lm_head # 4. perform compression model_device = next(self.model.parameters()).device @@ -323,6 +310,7 @@ def is_leaf_module(self, module: torch.nn.Module, module_qualified_name: str) -> inputs = {input_name: intermediates[input_name] for input_name in subgraph["input_names"]} inputs = tensors_to_device(inputs, model_device) + breakpoint() try: subgraph_output = forward_function(self.model, **inputs) except EarlyStopException: diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 76eabf35d..00985bc22 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -56,9 +56,6 @@ def register_hook( Ex. "forward", "forward_pre", "full_backward", "state_dict_post", "" :param kwargs: keyword arguments to pass to register hook method """ - - return None - @wraps(hook) def wrapped_hook(*args, **kwargs): if HooksMixin._HOOKS_DISABLED: From 1362ca2c650edb53543b283e074812aa3500e619 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 21 Nov 2024 18:50:30 -0500 Subject: [PATCH 100/285] able to run through pixtral without issue and using real proxy tensors. Requires patching modeling_llava --- .../gptq/utils/partitioned_model.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py index 9b8b84459..7abcd6255 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py @@ -210,7 +210,7 @@ def make_placeholders(tracer, model: torch.nn.Module, graph: GraphModule, dummy_ break else: - raise ValueError() + breakpoint() @@ -240,6 +240,15 @@ def is_leaf_module(self, module: torch.nn.Module, module_qualified_name: str) -> if type(module).__name__ in targets: return True # Treat as leaf, skip tracing inside this module return super().is_leaf_module(module, module_qualified_name) + + def to_bool(self, obj: 'Proxy') -> bool: + """Called when a proxy object is being converted to a boolean, such as + when used in control flow. Normally we don't know what to do because + we don't know the value of the proxy, but a custom tracer can attach more + information to the graph node using create_node and can choose to return a value. + """ + breakpoint() + return True with HooksMixin.disable_hooks(), calibration_forward_context(self.model): @@ -257,7 +266,10 @@ def is_leaf_module(self, module: torch.nn.Module, module_qualified_name: str) -> print(concrete_args) tracer = CustomTracer() remove_hook_from_module(self.model, recurse=True) - graph: GraphModule = tracer.trace(self.model, concrete_args=concrete_args, complete_concrete_args_with_inputs_not_in_dummy_inputs=False) + #model.to("cuda:0") + #graph: GraphModule = tracer.trace(self.model, concrete_args=concrete_args, complete_concrete_args_with_inputs_not_in_dummy_inputs=False) + concrete_args = make_fused_concrete_args(self.model, {}) + graph: GraphModule = tracer.trace(self.model, dummy_inputs=dummy_input) self.graph = torch.fx.GraphModule(self.model, graph) self.graph.config = self.model.config self.graph.class_for_deserialization = self.model.__class__ @@ -310,7 +322,6 @@ def forward_data( inputs = {input_name: intermediates[input_name] for input_name in subgraph["input_names"]} inputs = tensors_to_device(inputs, model_device) - breakpoint() try: subgraph_output = forward_function(self.model, **inputs) except EarlyStopException: From 0539df70ea22d7a9944c866d2eef322537ec845b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 25 Nov 2024 16:21:37 +0000 Subject: [PATCH 101/285] nits Signed-off-by: Kyle Sayers --- .../modifiers/quantization/quantization/base.py | 8 ++++++-- src/llmcompressor/modifiers/utils/hooks.py | 3 ++- tests/llmcompressor/modifiers/utils/test_hooks.py | 2 ++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py index c3c0e732f..9b4516b52 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/base.py +++ b/src/llmcompressor/modifiers/quantization/quantization/base.py @@ -289,8 +289,12 @@ def register_calibration_hooks(self, module: Module): if output_quant: # hooks for attn modules if running kv_cache quant if is_attention_module_: - pre_hook = calibrate_kv_cache_input_hook - self.register_hook(module, pre_hook, "forward_pre", with_kwargs=True) + self.register_hook( + module, + calibrate_kv_cache_input_hook, + "forward_pre", + with_kwargs=True, + ) self.register_hook(module, calibrate_kv_cache_output_hook, "forward") diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 44e9c13f9..bb1755519 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -68,7 +68,8 @@ def wrapped_hook(*args, **kwargs): return hook(*args, **kwargs) - handle = getattr(target, f"register_{hook_type}_hook")(wrapped_hook, **kwargs) + register_function = getattr(target, f"register_{hook_type}_hook") + handle = register_function(wrapped_hook, **kwargs) self._hooks.append(handle) logger.debug(f"{self} added {handle}") diff --git a/tests/llmcompressor/modifiers/utils/test_hooks.py b/tests/llmcompressor/modifiers/utils/test_hooks.py index 79ab3a1b4..5c4fc5891 100644 --- a/tests/llmcompressor/modifiers/utils/test_hooks.py +++ b/tests/llmcompressor/modifiers/utils/test_hooks.py @@ -4,6 +4,8 @@ class DummyModel(torch.nn.Module): + """Dummy Model for testing hooks""" + def __init__(self): super(DummyModel, self).__init__() From ed96ee4dc026506d4482041b4190a8ef6fef5cda Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 25 Nov 2024 17:36:55 +0000 Subject: [PATCH 102/285] fix all variable --- src/llmcompressor/utils/metric_logging.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/utils/metric_logging.py b/src/llmcompressor/utils/metric_logging.py index 493559553..424cb3b0b 100644 --- a/src/llmcompressor/utils/metric_logging.py +++ b/src/llmcompressor/utils/metric_logging.py @@ -4,10 +4,8 @@ import torch from loguru import logger from torch.nn import Module - -__all__ = ["CompressionLogger"] - -__all__ = ["get_GPU_memory_usage", "get_layer_size_mb"] +g +__all__ = ["get_GPU_memory_usage", "get_layer_size_mb", "CompressionLogger"] def get_GPU_memory_usage() -> List[Tuple]: From 5f2671178b2218c4ac8a631f294575f7b1d66f61 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 25 Nov 2024 19:36:35 +0000 Subject: [PATCH 103/285] tmp --- examples/quantization_w4a16/llama3_example.py | 13 +- .../quantization_w4a16/vision2_example.py | 10 +- examples/quantization_w4a16/vision_example.py | 24 +++- src/llmcompressor/data_pipelines/peicewise.py | 0 .../modifiers/quantization/gptq/base.py | 32 +++-- .../quantization/gptq/utils/gptq_quantize.py | 27 ++++- .../gptq/utils/partitioned_model.py | 112 ++++++++++-------- .../quantization/quantization/base.py | 2 +- src/llmcompressor/modifiers/utils/hooks.py | 3 +- .../modifiers/utils/pytorch_helpers.py | 2 +- src/llmcompressor/pytorch/utils/helpers.py | 4 +- .../transformers/finetune/data/base.py | 4 +- .../transformers/finetune/text_generation.py | 2 +- src/llmcompressor/utils/fsdp/helpers.py | 57 +++++---- src/llmcompressor/utils/metric_logging.py | 2 +- .../modifiers/utils/test_hooks.py | 1 + 16 files changed, 190 insertions(+), 105 deletions(-) create mode 100644 src/llmcompressor/data_pipelines/peicewise.py diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index ba078ac1c..ebb6f6934 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -6,7 +6,7 @@ from llmcompressor.transformers import oneshot # Select model and load it. -#MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" +# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" model = AutoModelForCausalLM.from_pretrained( @@ -23,7 +23,7 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 285 #2048 +NUM_CALIBRATION_SAMPLES = 285 # 2048 MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. @@ -58,7 +58,14 @@ def tokenize(sample): # Configure the quantization algorithm to run. # * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"], update_size=NUM_CALIBRATION_SAMPLES, dampening_frac=0.5) +recipe = GPTQModifier( + targets="Linear", + scheme="W4A16", + ignore=["lm_head"], + update_size=NUM_CALIBRATION_SAMPLES, + dampening_frac=0.5, + actorder="dynamic", +) # Apply algorithms. oneshot( diff --git a/examples/quantization_w4a16/vision2_example.py b/examples/quantization_w4a16/vision2_example.py index 1f57bb9f9..e3d25fbcd 100644 --- a/examples/quantization_w4a16/vision2_example.py +++ b/examples/quantization_w4a16/vision2_example.py @@ -20,7 +20,7 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 160 #2048 +NUM_CALIBRATION_SAMPLES = 160 # 2048 MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. @@ -56,7 +56,13 @@ def tokenize(sample): # Configure the quantization algorithm to run. # * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"], batch_size=1, dampening_frac=0.5) +recipe = GPTQModifier( + targets="Linear", + scheme="W4A16", + ignore=["lm_head"], + batch_size=1, + dampening_frac=0.5, +) # Apply algorithms. oneshot( diff --git a/examples/quantization_w4a16/vision_example.py b/examples/quantization_w4a16/vision_example.py index 88fc79983..2806eec92 100644 --- a/examples/quantization_w4a16/vision_example.py +++ b/examples/quantization_w4a16/vision_example.py @@ -20,7 +20,7 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 165 #2048 +NUM_CALIBRATION_SAMPLES = 165 # 2048 MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. @@ -32,11 +32,11 @@ def preprocess(example): messages = [ [ { - "role": "user", + "role": "user", "content": [ {"type": "image"}, - {"type": "text", "text": "What does the image show?"} - ] + {"type": "text", "text": "What does the image show?"}, + ], } ], ] @@ -53,14 +53,26 @@ def preprocess(example): # Tokenize inputs. def tokenize(sample): - return processor(sample["image"], sample["text"], add_special_tokens=False, return_tensors="pt", max_length=MAX_SEQUENCE_LENGTH) + return processor( + sample["image"], + sample["text"], + add_special_tokens=False, + return_tensors="pt", + max_length=MAX_SEQUENCE_LENGTH, + ) ds = ds.map(tokenize, remove_columns=ds.column_names) # Configure the quantization algorithm to run. # * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"], update_size=NUM_CALIBRATION_SAMPLES, dampening_frac=0.5) +recipe = GPTQModifier( + targets="Linear", + scheme="W4A16", + ignore=["lm_head"], + update_size=NUM_CALIBRATION_SAMPLES, + dampening_frac=0.5, +) # Apply algorithms. oneshot( diff --git a/src/llmcompressor/data_pipelines/peicewise.py b/src/llmcompressor/data_pipelines/peicewise.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index a57bfd939..b018cbd45 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -1,19 +1,23 @@ -from functools import partial import warnings +from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union import torch -from compressed_tensors.quantization import ( - QuantizationScheme, -) -from llmcompressor.modifiers.quantization.calibration import freeze_module_quantization +from compressed_tensors.quantization import QuantizationScheme from loguru import logger from pydantic import Field, PrivateAttr, field_validator from llmcompressor.core import State from llmcompressor.modifiers import Modifier, ModifierFactory -from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import accumulate_hessian, make_empty_hessian, quantize_weight -from llmcompressor.modifiers.quantization.gptq.utils.partitioned_model import PartitionedModel +from llmcompressor.modifiers.quantization.calibration import freeze_module_quantization +from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import ( + accumulate_hessian, + make_empty_hessian, + quantize_weight, +) +from llmcompressor.modifiers.quantization.gptq.utils.partitioned_model import ( + PartitionedModel, +) from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException @@ -26,7 +30,6 @@ calibration_forward_context, getattr_chain, ) - from llmcompressor.utils.metric_logging import CompressionLogger from llmcompressor.utils.pytorch.module import get_no_split_params, qat_active @@ -116,7 +119,9 @@ class GPTQModifier(Modifier, HooksMixin): disable_quantization_observer_epoch: Optional[float] = None _quantization_modifier: Optional[QuantizationModifier] = PrivateAttr() - _hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=lambda: {}) + _hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr( + default_factory=lambda: {} + ) _num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=lambda: {}) @field_validator("sequential_update", mode="before") @@ -196,10 +201,12 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self._quantization_modifier.initialize(state, **kwargs) if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") - + targets = get_no_split_params(state.model) partitioned_model = PartitionedModel() - partitioned_model.init_forward(state.model, targets, next(iter(state.data.calib))) + partitioned_model.init_forward( + state.model, targets, next(iter(state.data.calib)) + ) # register hooks for name, module in state.model.named_modules(): @@ -208,9 +215,10 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self.register_hook(module, post_hook, "forward") if "head" in name: + def hook(module: torch.nn.Module, args: Tuple[Any, ...]): raise EarlyStopException(None, None) - + self.register_hook(module, hook, "forward_pre") # feed data diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 384d9fc8e..f01544330 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -1,8 +1,7 @@ import math from copy import copy -from typing import Tuple, Union, Optional, Type +from typing import Optional, Tuple, Type, Union -from llmcompressor.observers.base import Observer import torch import transformers from compressed_tensors.quantization import ( @@ -13,6 +12,7 @@ ) from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD +from llmcompressor.observers.base import Observer from llmcompressor.pytorch.utils.helpers import tensor_sparsity GPTQ_PRECISION = torch.float32 @@ -21,10 +21,17 @@ def make_empty_hessian(module: torch.nn.Module): weight = module.weight num_columns = weight.shape[1] - return torch.zeros((num_columns, num_columns), device=weight.device, dtype=GPTQ_PRECISION) + return torch.zeros( + (num_columns, num_columns), device=weight.device, dtype=GPTQ_PRECISION + ) -def accumulate_hessian(inp: torch.Tensor, module_class: Type[torch.nn.Module], H: Optional[torch.Tensor] = None, num_samples: int = 1) -> Tuple[torch.Tensor, int]: +def accumulate_hessian( + inp: torch.Tensor, + module_class: Type[torch.nn.Module], + H: Optional[torch.Tensor] = None, + num_samples: int = 1, +) -> Tuple[torch.Tensor, int]: inp = inp.to(device=H.device) if len(inp.shape) == 2: inp = inp.unsqueeze(0) @@ -47,7 +54,9 @@ def accumulate_hessian(inp: torch.Tensor, module_class: Type[torch.nn.Module], H return H, num_samples -def compute_hessian(inp: torch.Tensor, module_class: Type[torch.nn.Module], device) -> torch.Tensor: +def compute_hessian( + inp: torch.Tensor, module_class: Type[torch.nn.Module], device +) -> torch.Tensor: """ Calculate the hessian with respect to the module inputs @@ -284,7 +293,13 @@ def quantize_weight( W = W.reshape(final_shape).to(final_dtype) loss = torch.sum(losses).item() - return loss, W, scale.to(dtype=final_dtype), zero_point.to(dtype=quant_args.pytorch_dtype()), g_idx + return ( + loss, + W, + scale.to(dtype=final_dtype), + zero_point.to(dtype=quant_args.pytorch_dtype()), + g_idx, + ) def _apply_activation_ordering( diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py index 7abcd6255..ce3cfaac0 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py @@ -1,19 +1,21 @@ - import contextlib import inspect +from collections import deque from typing import Any, Callable, Dict, List, Set, Tuple -import tqdm import torch -from collections import deque +import tqdm +from accelerate.hooks import remove_hook_from_module +from torch.fx import Graph, GraphModule, Node from transformers import AutoModel -from torch.fx import GraphModule, Graph, Node from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.utils.fx import symbolic_trace, HFTracer -from accelerate.hooks import remove_hook_from_module +from transformers.utils.fx import HFTracer, symbolic_trace from llmcompressor.modifiers.utils.hooks import HooksMixin -from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException, apply_pad_mask_to_batch +from llmcompressor.modifiers.utils.pytorch_helpers import ( + EarlyStopException, + apply_pad_mask_to_batch, +) from llmcompressor.pytorch.utils.helpers import tensors_to_device from llmcompressor.utils.helpers import calibration_forward_context @@ -22,8 +24,8 @@ def get_target_nodes(graph: GraphModule, targets: List[str]): target_nodes = [] for node in graph.graph.nodes: if ( - node.op == "call_module" and - type(graph.get_submodule(node.target)).__name__ in targets + node.op == "call_module" + and type(graph.get_submodule(node.target)).__name__ in targets ): target_nodes.append(node) @@ -40,16 +42,17 @@ def check_assumption(graph: Graph) -> bool: if node not in input_node.users: return False - if ( - len(node.users) != len(set(node.users)) or - len(node.all_input_nodes) != len(set(node.all_input_nodes)) + if len(node.users) != len(set(node.users)) or len(node.all_input_nodes) != len( + set(node.all_input_nodes) ): return False return True -def topological_partition(graph: GraphModule, target_nodes: Set[Node]) -> List[List[Node]]: +def topological_partition( + graph: GraphModule, target_nodes: Set[Node] +) -> List[List[Node]]: # use list representation to maintain topological sorting assert check_assumption(graph.graph) @@ -115,7 +118,7 @@ def partition_graph(model: torch.nn.Module, partitions: List[List[Node]]): new_input_nodes = { input_node for node in partition_nodes - #if node.op != "get_attr" + # if node.op != "get_attr" for input_node in node.all_input_nodes if input_node not in partition_nodes and input_node.op } @@ -138,12 +141,14 @@ def partition_graph(model: torch.nn.Module, partitions: List[List[Node]]): # Save the subgraph for this partition subgraph.lint() input_names = [node.name for node in subgraph.nodes if node.op == "placeholder"] - subgraphs.append({ - "graph": subgraph, - "code": subgraph.python_code("self"), - "input_names": input_names, - "consumed_names": [], - }) + subgraphs.append( + { + "graph": subgraph, + "code": subgraph.python_code("self"), + "input_names": input_names, + "consumed_names": [], + } + ) print([n for n in subgraph.nodes]) assert check_assumption(subgraph) @@ -174,8 +179,10 @@ def make_fused_concrete_args(root: torch.nn.Module, dummy_inputs: Dict[str, Any] if param.name in dummy_inputs: continue if param.default is inspect.Parameter.empty: - raise ValueError(f"You need to specify a default value for the parameter {param.name}.") - + raise ValueError( + f"You need to specify a default value for the parameter {param.name}." + ) + concrete_args.update( { p.name: p.default @@ -188,7 +195,9 @@ def make_fused_concrete_args(root: torch.nn.Module, dummy_inputs: Dict[str, Any] return concrete_args -def make_placeholders(tracer, model: torch.nn.Module, graph: GraphModule, dummy_inputs: Dict[str, Any]): +def make_placeholders( + tracer, model: torch.nn.Module, graph: GraphModule, dummy_inputs: Dict[str, Any] +): # TODO: this dictionary does not match tensors which have been deep copied # in general it's pretty annoying, since tracer.create_args_for_root basically # converts kwargs to args and therefore gets rid of any of the names. @@ -213,7 +222,6 @@ def make_placeholders(tracer, model: torch.nn.Module, graph: GraphModule, dummy_ breakpoint() - class PartitionedModel: def __init__(self): self.graph = None @@ -229,19 +237,24 @@ def partition_graph(self, graph: GraphModule, inputs: Tuple[Any, ...]): return graph.forward - def init_forward(self, model: torch.nn.Module, targets: List[str], dummy_input: Dict[str, Any]): + def init_forward( + self, model: torch.nn.Module, targets: List[str], dummy_input: Dict[str, Any] + ): self.model = model self.targets = targets # 1. trace graph targets = self.targets + class CustomTracer(HFTracer): - def is_leaf_module(self, module: torch.nn.Module, module_qualified_name: str) -> bool: + def is_leaf_module( + self, module: torch.nn.Module, module_qualified_name: str + ) -> bool: if type(module).__name__ in targets: return True # Treat as leaf, skip tracing inside this module return super().is_leaf_module(module, module_qualified_name) - - def to_bool(self, obj: 'Proxy') -> bool: + + def to_bool(self, obj: "Proxy") -> bool: """Called when a proxy object is being converted to a boolean, such as when used in control flow. Normally we don't know what to do because we don't know the value of the proxy, but a custom tracer can attach more @@ -249,25 +262,24 @@ def to_bool(self, obj: 'Proxy') -> bool: """ breakpoint() return True - - + with HooksMixin.disable_hooks(), calibration_forward_context(self.model): - #compiled = torch.compile(model, backend=self.partition_graph) - #compiled(**model.dummy_inputs) + # compiled = torch.compile(model, backend=self.partition_graph) + # compiled(**model.dummy_inputs) - #program = torch.export.export(model, tuple(), model.dummy_inputs, strict=False) - #program = torch.export.export(model, tuple(), {}, strict=False) # requires inputs + # program = torch.export.export(model, tuple(), model.dummy_inputs, strict=False) + # program = torch.export.export(model, tuple(), {}, strict=False) # requires inputs - #self.graph: GraphModule = symbolic_trace(model, disable_check=True, tracer_cls=CustomTracer) - #self.graph: GraphModule = CustomTracer().trace(model, dummy_inputs=model.dummy_inputs) + # self.graph: GraphModule = symbolic_trace(model, disable_check=True, tracer_cls=CustomTracer) + # self.graph: GraphModule = CustomTracer().trace(model, dummy_inputs=model.dummy_inputs) - #sample_input = next(iter(dataloader)) + # sample_input = next(iter(dataloader)) concrete_args = make_fused_concrete_args(self.model, dummy_input) print(concrete_args) tracer = CustomTracer() remove_hook_from_module(self.model, recurse=True) - #model.to("cuda:0") - #graph: GraphModule = tracer.trace(self.model, concrete_args=concrete_args, complete_concrete_args_with_inputs_not_in_dummy_inputs=False) + # model.to("cuda:0") + # graph: GraphModule = tracer.trace(self.model, concrete_args=concrete_args, complete_concrete_args_with_inputs_not_in_dummy_inputs=False) concrete_args = make_fused_concrete_args(self.model, {}) graph: GraphModule = tracer.trace(self.model, dummy_inputs=dummy_input) self.graph = torch.fx.GraphModule(self.model, graph) @@ -276,21 +288,19 @@ def to_bool(self, obj: 'Proxy') -> bool: self.graph.device = self.model.device make_placeholders(tracer, self.model, self.graph, dummy_input) - # 2. identify target nodes all_target_nodes = get_target_nodes(self.graph, self.targets) # 3. cut into partitions along target nodes - partitions: List[List[Node]] = topological_partition(self.graph, all_target_nodes) + partitions: List[List[Node]] = topological_partition( + self.graph, all_target_nodes + ) self.subgraphs: List[GraphModule] = partition_graph(self.model, partitions) trace_consumed_names(self.subgraphs) def forward_data( - self, - dataloader, - mask_padding: bool = True, - run_twice: bool = False + self, dataloader, mask_padding: bool = True, run_twice: bool = False ): # TODO: give option to skip lm_head # 4. perform compression @@ -309,18 +319,24 @@ def forward_data( if run_twice: for batch_index in range(len(dataloader)): intermediates = batch_intermediates[batch_index] - inputs = {input_name: intermediates[input_name] for input_name in subgraph["input_names"]} + inputs = { + input_name: intermediates[input_name] + for input_name in subgraph["input_names"] + } inputs = tensors_to_device(inputs, model_device) try: forward_function(self.model, **inputs) except EarlyStopException: pass - + with HooksMixin.disable_hooks() if run_twice else contextlib.nullcontext(): for batch_index in range(len(dataloader)): intermediates = batch_intermediates[batch_index] - inputs = {input_name: intermediates[input_name] for input_name in subgraph["input_names"]} + inputs = { + input_name: intermediates[input_name] + for input_name in subgraph["input_names"] + } inputs = tensors_to_device(inputs, model_device) try: subgraph_output = forward_function(self.model, **inputs) diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py index 2e2c58946..ad787b9d6 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/base.py +++ b/src/llmcompressor/modifiers/quantization/quantization/base.py @@ -209,7 +209,7 @@ def _check_calibration_data(self, config: QuantizationConfig): def _apply_modifier_to_model(self, model: Module): modifier_as_config = self.create_init_config() # Add step to attach kv_cache to the model, if present within the config - #apply_quantization_config(model, modifier_as_config) + # apply_quantization_config(model, modifier_as_config) model.apply(set_unset_kv_cache) return modifier_as_config diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index c17647677..2a7b562f9 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -56,6 +56,7 @@ def register_hook( Ex. "forward", "forward_pre", "full_backward", "state_dict_post", "" :param kwargs: keyword arguments to pass to register hook method """ + @wraps(hook) def wrapped_hook(*args, **kwargs): if HooksMixin._HOOKS_DISABLED: @@ -75,4 +76,4 @@ def remove_hooks(self): for hook in self._hooks: hook.remove() - self._hooks = [] \ No newline at end of file + self._hooks = [] diff --git a/src/llmcompressor/modifiers/utils/pytorch_helpers.py b/src/llmcompressor/modifiers/utils/pytorch_helpers.py index 8c7b8b318..9aeccd059 100644 --- a/src/llmcompressor/modifiers/utils/pytorch_helpers.py +++ b/src/llmcompressor/modifiers/utils/pytorch_helpers.py @@ -100,7 +100,7 @@ def run_calibration_forward( except EarlyStopException as e: # model was stopped early, save last calculated output and # move on to next calibration sample - #intermediates.append((e.args, e.kwargs)) + # intermediates.append((e.args, e.kwargs)) pass # TODO: not ideal, figure out where we aren't freeing memory instead diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py index db66793b4..2f9074f9b 100644 --- a/src/llmcompressor/pytorch/utils/helpers.py +++ b/src/llmcompressor/pytorch/utils/helpers.py @@ -304,7 +304,7 @@ def tensors_to_device( """ if isinstance(tensors, Tensor): return tensors.to(device) - + if isinstance(tensors, OrderedDict): return OrderedDict( [(key, tensors_to_device(tens, device)) for key, tens in tensors.items()] @@ -318,7 +318,7 @@ def tensors_to_device( if isinstance(tensors, Iterable): return [tensors_to_device(tens, device) for tens in tensors] - + return tensors diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index 941306180..bf0829080 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -59,7 +59,9 @@ def __init__( # configure sequence length max_seq_length = data_args.max_seq_length - model_max_length = self.tokenizer.model_max_length if self.tokenizer else max_seq_length + model_max_length = ( + self.tokenizer.model_max_length if self.tokenizer else max_seq_length + ) if self.tokenizer and max_seq_length > model_max_length: logger.warning( f"The max_seq_length passed ({max_seq_length}) is larger than " diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index 36b06974f..b76f09ae0 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -23,8 +23,8 @@ from loguru import logger from transformers import ( AutoConfig, - AutoProcessor, AutoModelForCausalLM, + AutoProcessor, AutoTokenizer, DefaultDataCollator, HfArgumentParser, diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py index f2f902344..133e8de24 100644 --- a/src/llmcompressor/utils/fsdp/helpers.py +++ b/src/llmcompressor/utils/fsdp/helpers.py @@ -1,9 +1,9 @@ import contextlib -from functools import wraps import operator +import warnings +from functools import wraps from pathlib import Path from typing import Optional -import warnings from loguru import logger @@ -27,7 +27,12 @@ try: from accelerate.hooks import AlignDevicesHook - from accelerate.utils import OffloadedWeightsLoader, PrefixedDataset, set_module_tensor_to_device + from accelerate.utils import ( + OffloadedWeightsLoader, + PrefixedDataset, + set_module_tensor_to_device, + ) + _has_accelerate = True except ImportError: _has_accelerate = False @@ -195,6 +200,7 @@ def get_fsdp_parent(layer_name: str, model: Module) -> Optional[Module]: return parent + # upstream candidate def has_offloaded_params(module: torch.nn.Module) -> bool: """ @@ -209,9 +215,9 @@ def has_offloaded_params(module: torch.nn.Module) -> bool: `False` otherwise. """ return ( - hasattr(module, "_hf_hook") and - isinstance(module._hf_hook, AlignDevicesHook) and - module._hf_hook.offload + hasattr(module, "_hf_hook") + and isinstance(module._hf_hook, AlignDevicesHook) + and module._hf_hook.offload ) @@ -245,7 +251,7 @@ def get_execution_device(module: torch.nn.Module) -> torch.device: def _infer_offload_device(module: torch.nn.Module) -> torch.device: if not has_offloaded_params(module): raise ValueError("Cannot infer offload device from non-offloaded module") - + first_key = next(module._hf_hook.weights_map.keys(), None) if first_key is None: raise ValueError("Cannot infer offload device from empty weights map") @@ -253,6 +259,7 @@ def _infer_offload_device(module: torch.nn.Module) -> torch.device: prefix_dataset = module._hf_hook.weights_map.dataset return prefix_dataset[first_key].device + # depreciation candidate def get_offloaded_device(module: torch.nn.Module) -> torch.device: """ @@ -296,13 +303,15 @@ def update_offload_parameter( param = getattr(module, name) if data is not None: if data.device == "meta": - raise ValueError("Cannot copy data from meta device. Consider calling with align_module(module) context") - + raise ValueError( + "Cannot copy data from meta device. Consider calling with align_module(module) context" + ) + if param.data.dtype != data.dtype: print(name) print((param.data.dtype, data.dtype)) warnings.warn("TODO") - + param.data.copy_(data) if has_offloaded_params(module): @@ -310,24 +319,29 @@ def update_offload_parameter( # for upstreaming, probably better to modify the weight map types so that they can be written to? if isinstance(weights_map, PrefixedDataset): - prefix_dict = getattr_chain(module, "module._hf_hook.weights_map.dataset", None) + prefix_dict = getattr_chain( + module, "module._hf_hook.weights_map.dataset", None + ) if prefix_dict is not None: prefix = module._hf_hook.weights_map.prefix key = f"{prefix}{name}" offload_device = ( - prefix_dict[key].device if key in prefix_dict - else offload_device if offload_device is not None + prefix_dict[key].device + if key in prefix_dict + else offload_device + if offload_device is not None else _infer_offload_device(module) ) prefix_dict[key] = param.data.to(device=offload_device) - + if isinstance(weights_map, OffloadedWeightsLoader): raise NotImplementedError() - + else: raise NotImplementedError() + # depreciation candidate def update_parameter_data( module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str @@ -339,7 +353,9 @@ def update_parameter_data( # upstream candidate @contextlib.contextmanager -def align_module(module: torch.nn.Module, execution_device: Optional[torch.device] = None): +def align_module( + module: torch.nn.Module, execution_device: Optional[torch.device] = None +): """ Moves a module's parameters to the specified execution device. @@ -386,7 +402,6 @@ def align_module(module: torch.nn.Module, execution_device: Optional[torch.devic yield - @contextlib.contextmanager def modify_offload_module( module: torch.nn.Module, @@ -427,9 +442,11 @@ def delete_offload_parameter(module: torch.nn.Module, name: str): prefix = weights_map.prefix if dataset is not None: del dataset[f"{prefix}{name}"] - + elif isinstance(weights_map, OffloadedWeightsLoader): raise NotImplementedError() - + elif weights_map is not None: - raise NotImplementedError(f"Cannot delete parameter from weights_map of type {type(weights_map)}") \ No newline at end of file + raise NotImplementedError( + f"Cannot delete parameter from weights_map of type {type(weights_map)}" + ) diff --git a/src/llmcompressor/utils/metric_logging.py b/src/llmcompressor/utils/metric_logging.py index 424cb3b0b..826cf59ab 100644 --- a/src/llmcompressor/utils/metric_logging.py +++ b/src/llmcompressor/utils/metric_logging.py @@ -4,7 +4,7 @@ import torch from loguru import logger from torch.nn import Module -g + __all__ = ["get_GPU_memory_usage", "get_layer_size_mb", "CompressionLogger"] diff --git a/tests/llmcompressor/modifiers/utils/test_hooks.py b/tests/llmcompressor/modifiers/utils/test_hooks.py index b0c1b89ad..5c4fc5891 100644 --- a/tests/llmcompressor/modifiers/utils/test_hooks.py +++ b/tests/llmcompressor/modifiers/utils/test_hooks.py @@ -28,6 +28,7 @@ class DummyMod(HooksMixin): def hook(self, *args, **kwargs): self.hook_called = True + class ModA(DummyMod): pass From ebc2c4123f528ec80442192449aaf1840769d453 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 26 Nov 2024 03:09:57 +0000 Subject: [PATCH 104/285] wip --- src/llmcompressor/data_pipelines/peicewise.py | 10 ++++++++++ .../quantization/gptq/utils/partitioned_model.py | 10 ---------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/llmcompressor/data_pipelines/peicewise.py b/src/llmcompressor/data_pipelines/peicewise.py index e69de29bb..f77af0869 100644 --- a/src/llmcompressor/data_pipelines/peicewise.py +++ b/src/llmcompressor/data_pipelines/peicewise.py @@ -0,0 +1,10 @@ +import torch + +from llmcompressor.utils.helpers import calibration_forward_context + + +def run_pipeline( + model: torch.nn.Module +): + with calibration_forward_context(model): + pass \ No newline at end of file diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py index ce3cfaac0..d0b473ba0 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py @@ -228,15 +228,6 @@ def __init__(self): self.subgraphs = [] self.model = None - def partition_graph(self, graph: GraphModule, inputs: Tuple[Any, ...]): - print("partition_graph") - - partitions = topological_partition(graph, self.targets) - subgraphs = partition_graph(self.model, partitions) - self.subgraphs.extend(subgraphs) - - return graph.forward - def init_forward( self, model: torch.nn.Module, targets: List[str], dummy_input: Dict[str, Any] ): @@ -260,7 +251,6 @@ def to_bool(self, obj: "Proxy") -> bool: we don't know the value of the proxy, but a custom tracer can attach more information to the graph node using create_node and can choose to return a value. """ - breakpoint() return True with HooksMixin.disable_hooks(), calibration_forward_context(self.model): From 922b407d11086d0403e8c94dea9bb2e143dcb75b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 26 Nov 2024 03:13:36 +0000 Subject: [PATCH 105/285] wip --- .../gptq/utils/partitioned_model.py | 27 ++----------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py index d0b473ba0..0964857f3 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py @@ -7,8 +7,6 @@ import tqdm from accelerate.hooks import remove_hook_from_module from torch.fx import Graph, GraphModule, Node -from transformers import AutoModel -from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.utils.fx import HFTracer, symbolic_trace from llmcompressor.modifiers.utils.hooks import HooksMixin @@ -254,29 +252,8 @@ def to_bool(self, obj: "Proxy") -> bool: return True with HooksMixin.disable_hooks(), calibration_forward_context(self.model): - # compiled = torch.compile(model, backend=self.partition_graph) - # compiled(**model.dummy_inputs) - - # program = torch.export.export(model, tuple(), model.dummy_inputs, strict=False) - # program = torch.export.export(model, tuple(), {}, strict=False) # requires inputs - - # self.graph: GraphModule = symbolic_trace(model, disable_check=True, tracer_cls=CustomTracer) - # self.graph: GraphModule = CustomTracer().trace(model, dummy_inputs=model.dummy_inputs) - - # sample_input = next(iter(dataloader)) - concrete_args = make_fused_concrete_args(self.model, dummy_input) - print(concrete_args) - tracer = CustomTracer() - remove_hook_from_module(self.model, recurse=True) - # model.to("cuda:0") - # graph: GraphModule = tracer.trace(self.model, concrete_args=concrete_args, complete_concrete_args_with_inputs_not_in_dummy_inputs=False) - concrete_args = make_fused_concrete_args(self.model, {}) - graph: GraphModule = tracer.trace(self.model, dummy_inputs=dummy_input) - self.graph = torch.fx.GraphModule(self.model, graph) - self.graph.config = self.model.config - self.graph.class_for_deserialization = self.model.__class__ - self.graph.device = self.model.device - make_placeholders(tracer, self.model, self.graph, dummy_input) + self.graph: GraphModule = symbolic_trace(model, disable_check=True, tracer_cls=CustomTracer) + #self.graph: GraphModule = CustomTracer().trace(model, dummy_inputs=model.dummy_inputs) # 2. identify target nodes all_target_nodes = get_target_nodes(self.graph, self.targets) From 0577f36dc835e11d8d0b98164123392510d6a6d6 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 25 Nov 2024 22:59:48 -0500 Subject: [PATCH 106/285] testing with lots of models --- llava.py | 108 ++++++++++++++++ mllama.py | 120 ++++++++++++++++++ qwen.py | 109 ++++++++++++++++ shubhra.py | 18 ++- .../gptq/utils/partitioned_model.py | 8 +- 5 files changed, 354 insertions(+), 9 deletions(-) create mode 100644 llava.py create mode 100644 mllama.py create mode 100644 qwen.py diff --git a/llava.py b/llava.py new file mode 100644 index 000000000..81a6a58d6 --- /dev/null +++ b/llava.py @@ -0,0 +1,108 @@ +import torch +from datasets import load_dataset +from transformers import AutoProcessor, MllamaForConditionalGeneration, LlavaForConditionalGeneration, AutoModelForCausalLM, Qwen2VLForConditionalGeneration + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import oneshot +import os + +# Load model. +model_id = "llava-hf/llava-1.5-7b-hf" + +model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="cuda:0", _attn_implementation="eager") +processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + +print("Loading dataset") +DATASET_ID = "lmms-lab/flickr30k" +DATASET_SPLIT = "test[:512]" + +NUM_CALIBRATION_SAMPLES = 1 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) +ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) + +print("Preprocessing samples") +def preprocess(example): + messages = [ + [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What does the image show?"} + ] + } + ], + ] + return { + "text": processor.apply_chat_template( + messages, + add_generation_prompt=True, + ), + } + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + tmp = processor( + sample["image"], + sample["text"], + add_special_tokens=False, + return_tensors="pt" + ) + + # Remove batch dimension from each key + input_ids = tmp["input_ids"].squeeze(0) + attention_mask = tmp["attention_mask"].squeeze(0) + pixel_values = [tmp["pixel_values"][0][0].squeeze(0)] + + return { + "input_ids": torch.LongTensor(input_ids), + "attention_mask": attention_mask, + "pixel_values": pixel_values, + } + + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +print("Setting up quantization params") +# Configure the quantization algorithm and scheme. +# In this case, we: +# * quantize the weights to fp8 with per channel via ptq +# * quantize the activations to fp8 with dynamic per token +#ignore=["re:.*lm_head", "re:model.vision_embed_tokens.*"] +#ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*", "re:language_model.*cross_attn.*"], +ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"] + +recipe = [ + # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), + GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore, update_size=NUM_CALIBRATION_SAMPLES), +] + +save_name = model_id.split("/")[1] + "-W8A8" +save_path = os.path.join("./my_test/", save_name) +print("Starting quantization") +oneshot( + model=model, + tokenizer=model_id, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, + output_dir=save_path, +) + +#processor.save_pretrained(save_path) + +# Confirm generations of the quantized model look sane. +print("========== SAMPLE GENERATION ==============") +input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=20) +print(processor.decode(output[0])) +print("==========================================") diff --git a/mllama.py b/mllama.py new file mode 100644 index 000000000..55c2d10a3 --- /dev/null +++ b/mllama.py @@ -0,0 +1,120 @@ +import torch +from datasets import load_dataset +from transformers import AutoProcessor, MllamaForConditionalGeneration, LlavaForConditionalGeneration, AutoModelForCausalLM, Qwen2VLForConditionalGeneration + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import oneshot +import os + +# Load model. +model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" +#model = MllamaForConditionalGeneration.from_pretrained(model_id) +#model_id = "mgoin/pixtral-12b" +#model_id = "Qwen/Qwen2-VL-2B-Instruct" + +#model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="cuda:0", _attn_implementation="eager") +#model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") +#model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0", torch_dtype="auto") +model = MllamaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") +processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + +print("Loading dataset") +DATASET_ID = "lmms-lab/flickr30k" +DATASET_SPLIT = "test[:512]" + +NUM_CALIBRATION_SAMPLES = 1 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) +ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) + +print("Preprocessing samples") +def preprocess(example): + messages = [ + [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What does the image show?"} + ] + } + ], + ] + return { + "text": processor.apply_chat_template( + messages, + add_generation_prompt=True, + ), + } + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + tmp = processor( + sample["image"], + sample["text"], + add_special_tokens=False, + return_tensors="pt" + ) + + # Remove batch dimension from each key + input_ids = tmp["input_ids"].squeeze(0) + attention_mask = tmp["attention_mask"].squeeze(0) + pixel_values = [tmp["pixel_values"][0][0].squeeze(0)] + aspect_ratio_ids = tmp["aspect_ratio_ids"].squeeze(0) + aspect_ratio_mask = tmp["aspect_ratio_mask"].squeeze(0) + cross_attention_mask = tmp["cross_attention_mask"].squeeze(0) + + return { + "input_ids": torch.LongTensor(input_ids), + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "aspect_ratio_ids": aspect_ratio_ids, + "aspect_ratio_mask": aspect_ratio_mask, + "cross_attention_mask": cross_attention_mask, + } + + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +print("Setting up quantization params") +# Configure the quantization algorithm and scheme. +# In this case, we: +# * quantize the weights to fp8 with per channel via ptq +# * quantize the activations to fp8 with dynamic per token +#ignore=["re:.*lm_head", "re:model.vision_embed_tokens.*"] +#ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*", "re:language_model.*cross_attn.*"], +ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"] + +recipe = [ + # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), + GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore, update_size=NUM_CALIBRATION_SAMPLES), +] + +save_name = model_id.split("/")[1] + "-W8A8" +save_path = os.path.join("./my_test/", save_name) +print("Starting quantization") +oneshot( + model=model, + tokenizer=model_id, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, + output_dir=save_path, +) + +#processor.save_pretrained(save_path) + +# Confirm generations of the quantized model look sane. +print("========== SAMPLE GENERATION ==============") +input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=20) +print(processor.decode(output[0])) +print("==========================================") diff --git a/qwen.py b/qwen.py new file mode 100644 index 000000000..f1ff43fe4 --- /dev/null +++ b/qwen.py @@ -0,0 +1,109 @@ +import torch +from datasets import load_dataset +from transformers import AutoProcessor, MllamaForConditionalGeneration, LlavaForConditionalGeneration, AutoModelForCausalLM, Qwen2VLForConditionalGeneration + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import oneshot +import os + +# Load model. +model_id = "Qwen/Qwen2-VL-2B-Instruct" +model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") +processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + +print("Loading dataset") +DATASET_ID = "lmms-lab/flickr30k" +DATASET_SPLIT = "test[:512]" + +NUM_CALIBRATION_SAMPLES = 1 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) +ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) + +print("Preprocessing samples") +def preprocess(example): + messages = [ + [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What does the image show?"} + ] + } + ], + ] + return { + "text": processor.apply_chat_template( + messages, + add_generation_prompt=True, + ), + } + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + tmp = processor( + sample["image"], + sample["text"], + add_special_tokens=False, + return_tensors="pt" + ) + + # Remove batch dimension from each key + input_ids = tmp["input_ids"].squeeze(0) + attention_mask = tmp["attention_mask"].squeeze(0) + pixel_values = [tmp["pixel_values"][0][0].squeeze(0)] + image_grid_thw = tmp["image_grid_thw"].unsqueeze(0) + + return { + "input_ids": torch.LongTensor(input_ids), + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "image_grid_thw": image_grid_thw, + } + + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +print("Setting up quantization params") +# Configure the quantization algorithm and scheme. +# In this case, we: +# * quantize the weights to fp8 with per channel via ptq +# * quantize the activations to fp8 with dynamic per token +#ignore=["re:.*lm_head", "re:model.vision_embed_tokens.*"] +#ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*", "re:language_model.*cross_attn.*"], +ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"] + +recipe = [ + # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), + GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore, update_size=NUM_CALIBRATION_SAMPLES), +] + +save_name = model_id.split("/")[1] + "-W8A8" +save_path = os.path.join("./my_test/", save_name) +print("Starting quantization") +oneshot( + model=model, + tokenizer=model_id, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, + output_dir=save_path, +) + +#processor.save_pretrained(save_path) + +# Confirm generations of the quantized model look sane. +print("========== SAMPLE GENERATION ==============") +input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=20) +print(processor.decode(output[0])) +print("==========================================") diff --git a/shubhra.py b/shubhra.py index 4f549703c..1a8dc088f 100644 --- a/shubhra.py +++ b/shubhra.py @@ -1,24 +1,28 @@ import torch from datasets import load_dataset -from transformers import AutoProcessor, MllamaForConditionalGeneration, LlavaForConditionalGeneration, AutoModel +from transformers import AutoProcessor, MllamaForConditionalGeneration, LlavaForConditionalGeneration, AutoModelForCausalLM, Qwen2VLForConditionalGeneration from llmcompressor.modifiers.quantization import GPTQModifier from llmcompressor.transformers import oneshot import os # Load model. -#model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" +model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" #model = MllamaForConditionalGeneration.from_pretrained(model_id) -model_id = "mgoin/pixtral-12b" -model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="cuda:0", _attn_implementation="eager") -#model = AutoModel.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto", torch_dtype="auto") +#model_id = "mgoin/pixtral-12b" +#model_id = "Qwen/Qwen2-VL-2B-Instruct" + +#model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="cuda:0", _attn_implementation="eager") +#model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") +#model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0", torch_dtype="auto") +model = MllamaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) print("Loading dataset") DATASET_ID = "lmms-lab/flickr30k" DATASET_SPLIT = "test[:512]" -NUM_CALIBRATION_SAMPLES = 512 +NUM_CALIBRATION_SAMPLES = 1 MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. @@ -83,7 +87,7 @@ def tokenize(sample): recipe = [ # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), - GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore), + GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore, update_size=NUM_CALIBRATION_SAMPLES), ] save_name = model_id.split("/")[1] + "-W8A8" diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py index 0964857f3..b5e56ef0f 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py @@ -252,8 +252,12 @@ def to_bool(self, obj: "Proxy") -> bool: return True with HooksMixin.disable_hooks(), calibration_forward_context(self.model): - self.graph: GraphModule = symbolic_trace(model, disable_check=True, tracer_cls=CustomTracer) - #self.graph: GraphModule = CustomTracer().trace(model, dummy_inputs=model.dummy_inputs) + #self.graph: GraphModule = symbolic_trace(model, disable_check=True, tracer_cls=CustomTracer) + self.graph: GraphModule = torch.fx.GraphModule(model, CustomTracer().trace(model, dummy_inputs=model.dummy_inputs, concrete_args={"use_cache": False})) + self.graph.config = model.config + self.graph.class_for_deserialization = model.__class__ + self.graph.device = model.device + self.graph: GraphModule # 2. identify target nodes all_target_nodes = get_target_nodes(self.graph, self.targets) From 3830696ba685f57dd8bd969f065584d35888cb06 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 26 Nov 2024 04:35:37 +0000 Subject: [PATCH 107/285] preliminary data pipeline --- src/llmcompressor/data_pipelines/peicewise.py | 73 ++++++++++++++++++- .../gptq/utils/partitioned_model.py | 8 -- 2 files changed, 70 insertions(+), 11 deletions(-) diff --git a/src/llmcompressor/data_pipelines/peicewise.py b/src/llmcompressor/data_pipelines/peicewise.py index f77af0869..f7adc00d8 100644 --- a/src/llmcompressor/data_pipelines/peicewise.py +++ b/src/llmcompressor/data_pipelines/peicewise.py @@ -1,10 +1,77 @@ +import contextlib import torch -from llmcompressor.utils.helpers import calibration_forward_context +from datasets import Dataset + +from llmcompressor.core.session_functions import initialize +from llmcompressor.modifiers.utils.hooks import HooksMixin +from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException +from llmcompressor.recipe.recipe import Recipe +from llmcompressor.utils.helpers import calibration_forward_context, trace_subgraphs, get_targets, get_model_device, tensors_to_device, create_dataloader def run_pipeline( - model: torch.nn.Module + model: torch.nn.Module, + recipe: Recipe, + dataset: Dataset, + propagate_error: bool, ): + # trace subgraphs + targets = get_targets(recipe) + sample_input_names = next(iter(dataset)).keys() + subgraphs = trace_subgraphs(model, sample_input_names, targets) + + # apply recipe to model + initialize(recipe, model) + + # create dataloader + model_device = get_model_device(model) + dataloader = create_dataloader(dataset, batch_size=..., mask_padding=True, model_device=model_device) + with calibration_forward_context(model): - pass \ No newline at end of file + # prepare intermediates cache + batch_intermediates = list(iter(dataloader)) + batch_outputs = [None for _ in range(len(dataloader))] + + for subgraph_index, subgraph in enumerate(subgraphs): + # compile subgraph forward function + code = subgraph["code"] + exec(code.src, code.globals) + forward_function = code.globals.get("forward") + + if propagate_error: + # do an preliminary pass to trigger modifier hooks + for batch_index in range(len(dataloader)): + intermediates = batch_intermediates[batch_index] + inputs = { + input_name: intermediates[input_name] + for input_name in subgraph["input_names"] + } + inputs = tensors_to_device(inputs, model_device) + try: + forward_function(model, **inputs) + except EarlyStopException: + pass + + # if using propagate_error, then this pass does not trigger modifier hooks + # and is only used for capturing intermediates + # otherwise, this pass triggers modifier hooks and captures intermediates + with HooksMixin.disable_hooks() if propagate_error else contextlib.nullcontext(): + for batch_index in range(len(dataloader)): + intermediates = batch_intermediates[batch_index] + + inputs = { + input_name: intermediates[input_name] + for input_name in subgraph["input_names"] + } + inputs = tensors_to_device(inputs, model_device) + subgraph_output = forward_function(model, **inputs) + subgraph_output = tensors_to_device(subgraph_output, "cpu") + + for consumed_name in subgraph["consumed_names"]: + del intermediates[consumed_name] + + if subgraph_index < len(subgraphs) - 1: + intermediates.update(subgraph_output) + else: + batch_outputs[batch_index] = subgraph_output diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py index b5e56ef0f..4605f11e2 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py @@ -243,14 +243,6 @@ def is_leaf_module( return True # Treat as leaf, skip tracing inside this module return super().is_leaf_module(module, module_qualified_name) - def to_bool(self, obj: "Proxy") -> bool: - """Called when a proxy object is being converted to a boolean, such as - when used in control flow. Normally we don't know what to do because - we don't know the value of the proxy, but a custom tracer can attach more - information to the graph node using create_node and can choose to return a value. - """ - return True - with HooksMixin.disable_hooks(), calibration_forward_context(self.model): #self.graph: GraphModule = symbolic_trace(model, disable_check=True, tracer_cls=CustomTracer) self.graph: GraphModule = torch.fx.GraphModule(model, CustomTracer().trace(model, dummy_inputs=model.dummy_inputs, concrete_args={"use_cache": False})) From 1ecaa392a511393d3d7a87dcabac4a42c19d4462 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 26 Nov 2024 05:10:02 +0000 Subject: [PATCH 108/285] WIP --- src/llmcompressor/data_pipelines/peicewise.py | 5 +---- src/llmcompressor/modifiers/quantization/gptq/base.py | 7 ------- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/src/llmcompressor/data_pipelines/peicewise.py b/src/llmcompressor/data_pipelines/peicewise.py index f7adc00d8..cbd839cc0 100644 --- a/src/llmcompressor/data_pipelines/peicewise.py +++ b/src/llmcompressor/data_pipelines/peicewise.py @@ -48,10 +48,7 @@ def run_pipeline( for input_name in subgraph["input_names"] } inputs = tensors_to_device(inputs, model_device) - try: - forward_function(model, **inputs) - except EarlyStopException: - pass + forward_function(model, **inputs) # if using propagate_error, then this pass does not trigger modifier hooks # and is only used for capturing intermediates diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index b018cbd45..900bded8b 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -214,13 +214,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool: post_hook = partial(self.compress_module, name) self.register_hook(module, post_hook, "forward") - if "head" in name: - - def hook(module: torch.nn.Module, args: Tuple[Any, ...]): - raise EarlyStopException(None, None) - - self.register_hook(module, hook, "forward_pre") - # feed data with calibration_forward_context(state.model): partitioned_model.forward_data(state.data.calib, mask_padding=True) From 9aa9679937b336d9ad375cff843508cd82864317 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 26 Nov 2024 18:16:32 +0000 Subject: [PATCH 109/285] delete unnecessary files --- examples/quantization_w4a16/llama3_example.py | 18 +- .../quantization_w4a16/vision2_example.py | 89 ----- examples/quantization_w4a16/vision_example.py | 91 ------ graph_resuming.py | 308 ------------------ llava.py | 108 ------ mllama.py | 120 ------- shubhra.py | 114 ------- 7 files changed, 4 insertions(+), 844 deletions(-) delete mode 100644 examples/quantization_w4a16/vision2_example.py delete mode 100644 examples/quantization_w4a16/vision_example.py delete mode 100644 graph_resuming.py delete mode 100644 llava.py delete mode 100644 mllama.py delete mode 100644 shubhra.py diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index ebb6f6934..c08165299 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -1,4 +1,3 @@ -from accelerate import cpu_offload from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer @@ -6,15 +5,13 @@ from llmcompressor.transformers import oneshot # Select model and load it. -# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" -MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" +MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" model = AutoModelForCausalLM.from_pretrained( MODEL_ID, - device_map="cuda:0", + device_map="auto", torch_dtype="auto", ) -# cpu_offload(model) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) # Select calibration dataset. @@ -23,7 +20,7 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 285 # 2048 +NUM_CALIBRATION_SAMPLES = 512 MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. @@ -58,14 +55,7 @@ def tokenize(sample): # Configure the quantization algorithm to run. # * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = GPTQModifier( - targets="Linear", - scheme="W4A16", - ignore=["lm_head"], - update_size=NUM_CALIBRATION_SAMPLES, - dampening_frac=0.5, - actorder="dynamic", -) +recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]) # Apply algorithms. oneshot( diff --git a/examples/quantization_w4a16/vision2_example.py b/examples/quantization_w4a16/vision2_example.py deleted file mode 100644 index e3d25fbcd..000000000 --- a/examples/quantization_w4a16/vision2_example.py +++ /dev/null @@ -1,89 +0,0 @@ -from datasets import load_dataset -from transformers import AutoProcessor, MllamaForConditionalGeneration - -from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.transformers import oneshot - -# Select model and load it. -MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" - -model = MllamaForConditionalGeneration.from_pretrained( - MODEL_ID, - device_map="cuda:0", - torch_dtype="auto", -) -processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True) - -# Select calibration dataset. -DATASET_ID = "HuggingFaceH4/ultrachat_200k" -DATASET_SPLIT = "train_sft" - -# Select number of samples. 512 samples is a good place to start. -# Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 160 # 2048 -MAX_SEQUENCE_LENGTH = 2048 - -# Load dataset and preprocess. -ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) -ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) - - -def preprocess(example): - return { - "text": processor.apply_chat_template( - example["messages"], - tokenize=False, - ) - } - - -ds = ds.map(preprocess) - - -# Tokenize inputs. -def tokenize(sample): - return processor( - None, - sample["text"], - padding=False, - max_length=MAX_SEQUENCE_LENGTH, - truncation=True, - add_special_tokens=False, - ) - - -ds = ds.map(tokenize, remove_columns=ds.column_names) - -# Configure the quantization algorithm to run. -# * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = GPTQModifier( - targets="Linear", - scheme="W4A16", - ignore=["lm_head"], - batch_size=1, - dampening_frac=0.5, -) - -# Apply algorithms. -oneshot( - model=model, - tokenizer=MODEL_ID, - dataset=ds, - recipe=recipe, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, - trust_remote_code_model=True, -) - -# Confirm generations of the quantized model look sane. -print("\n\n") -print("========== SAMPLE GENERATION ==============") -input_ids = processor("Hello my name is", return_tensors="pt").input_ids.to("cuda") -output = model.generate(input_ids, max_new_tokens=100) -print(processor.decode(output[0])) -print("==========================================\n\n") - -# Save to disk compressed. -SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128" -model.save_pretrained(SAVE_DIR, save_compressed=True) -processor.save_pretrained(SAVE_DIR) diff --git a/examples/quantization_w4a16/vision_example.py b/examples/quantization_w4a16/vision_example.py deleted file mode 100644 index 2806eec92..000000000 --- a/examples/quantization_w4a16/vision_example.py +++ /dev/null @@ -1,91 +0,0 @@ -from datasets import load_dataset -from transformers import AutoProcessor, MllamaForConditionalGeneration - -from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot - -# Select model and load it. -MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" - -model = MllamaForConditionalGeneration.from_pretrained( - MODEL_ID, - device_map="cuda:0", - torch_dtype="auto", -) -processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True) - -# Select calibration dataset. -DATASET_ID = "lmms-lab/flickr30k" -DATASET_SPLIT = "test[:165]" - -# Select number of samples. 512 samples is a good place to start. -# Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 165 # 2048 -MAX_SEQUENCE_LENGTH = 2048 - -# Load dataset and preprocess. -ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) -ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) - - -def preprocess(example): - messages = [ - [ - { - "role": "user", - "content": [ - {"type": "image"}, - {"type": "text", "text": "What does the image show?"}, - ], - } - ], - ] - return { - "text": processor.apply_chat_template( - messages, - add_generation_prompt=True, - ), - } - - -ds = ds.map(preprocess) - - -# Tokenize inputs. -def tokenize(sample): - return processor( - sample["image"], - sample["text"], - add_special_tokens=False, - return_tensors="pt", - max_length=MAX_SEQUENCE_LENGTH, - ) - - -ds = ds.map(tokenize, remove_columns=ds.column_names) - -# Configure the quantization algorithm to run. -# * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = GPTQModifier( - targets="Linear", - scheme="W4A16", - ignore=["lm_head"], - update_size=NUM_CALIBRATION_SAMPLES, - dampening_frac=0.5, -) - -# Apply algorithms. -oneshot( - model=model, - tokenizer=MODEL_ID, - dataset=ds, - recipe=recipe, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, - trust_remote_code_model=True, -) - -# Save to disk compressed. -SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128" -model.save_pretrained(SAVE_DIR, save_compressed=True) -processor.save_pretrained(SAVE_DIR) diff --git a/graph_resuming.py b/graph_resuming.py deleted file mode 100644 index 14e0520ea..000000000 --- a/graph_resuming.py +++ /dev/null @@ -1,308 +0,0 @@ -from typing import Any, Callable, Dict, List, Set - -import torch -import inspect -from collections import deque -from transformers import AutoModel -from torch.fx import GraphModule, Graph, Node -from transformers.modeling_outputs import BaseModelOutputWithPast - - -class Model(torch.nn.Module): - def __init__(self, vocab_size=4096, d_model=128, n_heads=1, d_ff=256, dropout=0.1): - super(Model, self).__init__() - - self.d_model = d_model - self.n_heads = n_heads - self.head_dim = d_model // n_heads - assert d_model % n_heads == 0, "d_model must be divisible by n_heads" - - # Embedding layer - self.embedding = torch.nn.Embedding(vocab_size, d_model) - - # Linear transformations for queries, keys, and values - self.query_linear = torch.nn.Linear(d_model, d_model) - self.key_linear = torch.nn.Linear(d_model, d_model) - self.value_linear = torch.nn.Linear(d_model, d_model) - - # Output linear layer to combine heads - self.out_linear = torch.nn.Linear(d_model, d_model) - - # Position-wise feed-forward network - self.feed_forward = torch.nn.Sequential( - torch.nn.Linear(d_model, d_ff), - torch.nn.ReLU(), - torch.nn.Linear(d_ff, d_model) - ) - - # Layer normalization layers - self.norm1 = torch.nn.LayerNorm(d_model) - self.norm2 = torch.nn.LayerNorm(d_model) - - # Dropout layer - self.dropout = torch.nn.Dropout(dropout) - - def scaled_dot_product_attention(self, query, key, value): - # Calculate attention scores - scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5) - attn_weights = torch.functional.F.softmax(scores, dim=-1) - output = torch.matmul(attn_weights, value) - return output - - def forward(self, input_ids): - # Apply embedding layer - x = self.embedding(input_ids) # (batch_size, seq_length, d_model) - - batch_size, seq_length, _ = x.size() - - # Linear projections - Q = self.query_linear(x) # (batch_size, seq_length, d_model) - K = self.key_linear(x) # (batch_size, seq_length, d_model) - V = self.value_linear(x) # (batch_size, seq_length, d_model) - - # Split Q, K, V into multiple heads - Q = Q.view(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) # (batch_size, n_heads, seq_length, head_dim) - K = K.view(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) # (batch_size, n_heads, seq_length, head_dim) - V = V.view(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) # (batch_size, n_heads, seq_length, head_dim) - - # Scaled dot-product attention - attn_output = self.scaled_dot_product_attention(Q, K, V) # (batch_size, n_heads, seq_length, head_dim) - - # Concatenate heads - attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model) - - # Apply final linear transformation - attn_output = self.out_linear(attn_output) - - # Add & Norm - x = x + self.dropout(attn_output) - x = self.norm1(x) - - # Feed-forward block - ff_output = self.feed_forward(x) - x = x + self.dropout(ff_output) - x = self.norm2(x) - - return BaseModelOutputWithPast(last_hidden_state=x) - - -def get_target_nodes(graph: GraphModule, targets: List[str]): - target_nodes = [] - for node in graph.graph.nodes: - if ( - node.op == "call_module" and - type(graph.get_submodule(node.target)).__name__ in targets - ): - target_nodes.append(node) - - return target_nodes - - -def check_assumption(graph: Graph) -> bool: - for node in graph.nodes: - for user in node.users: - if node not in user.all_input_nodes: - return False - - for input_node in node.all_input_nodes: - if node not in input_node.users: - return False - - if ( - len(node.users) != len(set(node.users)) or - len(node.all_input_nodes) != len(set(node.all_input_nodes)) - ): - return False - - return True - - -def topological_partition(graph: GraphModule, target_nodes: Set[Node]) -> List[List[Node]]: - # use list representation to maintain topological sorting - assert check_assumption(graph.graph) - - partitions: List[List[Node]] = [[]] - remaining_indegrees = {node: len(node.all_input_nodes) for node in graph.graph.nodes} - partition_index = 0 # global counter, not necessary but ensures partitions are connected - - # start with graph input nodes - queue = deque(node for node in graph.graph.nodes if remaining_indegrees[node] == 0) - while len(queue) > 0: - node = queue.popleft() - - # guarantee targets are assigned to disjoint partitions - if node in target_nodes: - partition_index += 1 - partitions.append([]) - - # assign to partition - partitions[partition_index].append(node) - - # recurse on last indegree only in order to guarantee that - # the node is assigned to maximal partition - for user in node.users: - remaining_indegrees[user] -= 1 - if remaining_indegrees[user] == 0: - queue.append(user) - - assert set().union(*partitions) == set(graph.graph.nodes) - return partitions - - -def partition_graph(model: torch.nn.Module, partitions: List[List[Node]]): - subgraphs = [] - - # create subgraphs - for partition_nodes in partitions: - # create a new graph for the partition - subgraph = Graph(model) - node_map = {} - - # add placeholders for inputs not in this subgraph. use set to deduplicate - new_input_nodes = { - input_node - for node in partition_nodes - if node.op != "get_attr" - for input_node in node.all_input_nodes - if input_node not in partition_nodes - } - for input_node in new_input_nodes: - node_map[input_node] = subgraph.placeholder(input_node.name) - - # add the nodes to subgraph - for node in partition_nodes: - node_map[node] = subgraph.node_copy(node, lambda n: node_map[n]) - - # add an output node to collect all subgraph outputs into a dictionary - if len(subgraph.find_nodes(op="output")) <= 0: - output_dict = { - node.name: node_map[node] - for node in partition_nodes - if any(user not in partition_nodes for user in node.users.keys()) - } - subgraph.output(output_dict) - - # Save the subgraph for this partition - subgraph.lint() - input_names = [node.name for node in subgraph.nodes if node.op == "placeholder"] - subgraphs.append({ - "graph": subgraph, - "code": subgraph.python_code("self"), - "input_names": input_names, - "consumed_names": [], - }) - - print([n for n in subgraph.nodes]) - assert check_assumption(subgraph) - - # populate consumed_names according to when inputs are last used - # in order to vacate the `intermediates` cache and save memory - all_input_names = set().union(*(subgraph["input_names"] for subgraph in subgraphs)) - for input_name in all_input_names: - for subgraph in reversed(subgraphs): - if input_name in subgraph["input_names"]: - subgraph["consumed_names"].append(input_name) - break - else: - assert False - - return subgraphs - - -def gptq_compress(name: str, module: torch.nn.Module, inputs: List[torch.Tensor]): - print(f"gptq_compress {name} {module} {inputs.shape}") - pass - - -class HookedModel: - def __init__(self): - self.hook_targets = [] - self.hook_target_nodes = [] - self.graph = None - self.subgraphs = [] - self.model = None - - def register_hook(self, func: Callable, targets: List[str]): - self.hook_targets.append((func, targets)) - - def init_forward(self, model: torch.nn.Module): - self.model = model - - # 1. create graph - self.graph: GraphModule = symbolic_trace(model) - - # 2. identify target nodes - for func, targets in self.hook_targets: - self.hook_target_nodes.append((func, get_target_nodes(self.graph, targets))) - - all_target_nodes = set().union(*(target_nodes for _, target_nodes in self.hook_target_nodes)) - - # 3. cut into partitions along target nodes - partitions: List[List[Node]] = topological_partition(self.graph, all_target_nodes) - self.subgraphs: List[GraphModule] = partition_graph(model, partitions) - - def forward(self, *args, **kwargs): - model_modules = {name: module for name, module in self.model.named_modules()} - - # 4. perform compression - intermediates = kwargs.copy() - for subgraph_index, subgraph in enumerate(self.subgraphs): - code = subgraph["code"] - exec(code.src, code.globals) - forward_function = code.globals.get("forward") - - inputs = {input_name: intermediates[input_name] for input_name in subgraph["input_names"]} - - # detect and call hooks - for func, target_nodes in self.hook_target_nodes: - target_nodes = set(target_node for target_node in target_nodes) - subgraph_node_names = set(node.name for node in subgraph["graph"].nodes if node.op == "call_module") - - for target_node in target_nodes: - if target_node.name in subgraph_node_names: - assert len(target_node.all_input_nodes) == 1 - - module = model_modules[target_node.target] - input_value = inputs[target_node.all_input_nodes[0].name] - func(target_node.target, module, input_value) - - if subgraph_index < len(self.subgraphs) - 1: - intermediates.update(forward_function(self.model, **inputs)) - - for consumed_name in subgraph["consumed_names"]: - del intermediates[consumed_name] - else: - return forward_function(self.model, **inputs) - - -if __name__ == "__main__": - use_dummy_model = False - sequence_length = 2048 - - if use_dummy_model: - model = Model() - from torch.fx import symbolic_trace - else: - model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B-Instruct") - from transformers.utils.fx import symbolic_trace - - data_loader = [ - {"input_ids": torch.zeros(sequence_length, dtype=torch.int32).reshape(1, sequence_length)}, - #{"input_ids": torch.zeros(sequence_length, dtype=torch.int32).reshape(1, sequence_length)}, - #{"input_ids": torch.zeros(sequence_length, dtype=torch.int32).reshape(1, sequence_length)}, - ] - - # modifier inits - hooked_model = HookedModel() - hooked_model.register_hook(gptq_compress, ["Linear"]) - - # some time after modifier inits but before forward passes - hooked_model.init_forward(model) - - # oneshot/ eval loop - model.eval() - with torch.no_grad(): - for batch in data_loader: - hooked_output = hooked_model.forward(**batch) - model_output = model.forward(**batch) - assert torch.equal(hooked_output["last_hidden_state"], model_output["last_hidden_state"]) \ No newline at end of file diff --git a/llava.py b/llava.py deleted file mode 100644 index 81a6a58d6..000000000 --- a/llava.py +++ /dev/null @@ -1,108 +0,0 @@ -import torch -from datasets import load_dataset -from transformers import AutoProcessor, MllamaForConditionalGeneration, LlavaForConditionalGeneration, AutoModelForCausalLM, Qwen2VLForConditionalGeneration - -from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.transformers import oneshot -import os - -# Load model. -model_id = "llava-hf/llava-1.5-7b-hf" - -model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="cuda:0", _attn_implementation="eager") -processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) - -print("Loading dataset") -DATASET_ID = "lmms-lab/flickr30k" -DATASET_SPLIT = "test[:512]" - -NUM_CALIBRATION_SAMPLES = 1 -MAX_SEQUENCE_LENGTH = 2048 - -# Load dataset and preprocess. -ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) -ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) - -print("Preprocessing samples") -def preprocess(example): - messages = [ - [ - { - "role": "user", - "content": [ - {"type": "image"}, - {"type": "text", "text": "What does the image show?"} - ] - } - ], - ] - return { - "text": processor.apply_chat_template( - messages, - add_generation_prompt=True, - ), - } - -ds = ds.map(preprocess) - - -# Tokenize inputs. -def tokenize(sample): - tmp = processor( - sample["image"], - sample["text"], - add_special_tokens=False, - return_tensors="pt" - ) - - # Remove batch dimension from each key - input_ids = tmp["input_ids"].squeeze(0) - attention_mask = tmp["attention_mask"].squeeze(0) - pixel_values = [tmp["pixel_values"][0][0].squeeze(0)] - - return { - "input_ids": torch.LongTensor(input_ids), - "attention_mask": attention_mask, - "pixel_values": pixel_values, - } - - - -ds = ds.map(tokenize, remove_columns=ds.column_names) - -print("Setting up quantization params") -# Configure the quantization algorithm and scheme. -# In this case, we: -# * quantize the weights to fp8 with per channel via ptq -# * quantize the activations to fp8 with dynamic per token -#ignore=["re:.*lm_head", "re:model.vision_embed_tokens.*"] -#ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*", "re:language_model.*cross_attn.*"], -ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"] - -recipe = [ - # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), - GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore, update_size=NUM_CALIBRATION_SAMPLES), -] - -save_name = model_id.split("/")[1] + "-W8A8" -save_path = os.path.join("./my_test/", save_name) -print("Starting quantization") -oneshot( - model=model, - tokenizer=model_id, - dataset=ds, - recipe=recipe, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, - trust_remote_code_model=True, - output_dir=save_path, -) - -#processor.save_pretrained(save_path) - -# Confirm generations of the quantized model look sane. -print("========== SAMPLE GENERATION ==============") -input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") -output = model.generate(input_ids, max_new_tokens=20) -print(processor.decode(output[0])) -print("==========================================") diff --git a/mllama.py b/mllama.py deleted file mode 100644 index 55c2d10a3..000000000 --- a/mllama.py +++ /dev/null @@ -1,120 +0,0 @@ -import torch -from datasets import load_dataset -from transformers import AutoProcessor, MllamaForConditionalGeneration, LlavaForConditionalGeneration, AutoModelForCausalLM, Qwen2VLForConditionalGeneration - -from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.transformers import oneshot -import os - -# Load model. -model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" -#model = MllamaForConditionalGeneration.from_pretrained(model_id) -#model_id = "mgoin/pixtral-12b" -#model_id = "Qwen/Qwen2-VL-2B-Instruct" - -#model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="cuda:0", _attn_implementation="eager") -#model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") -#model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0", torch_dtype="auto") -model = MllamaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") -processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) - -print("Loading dataset") -DATASET_ID = "lmms-lab/flickr30k" -DATASET_SPLIT = "test[:512]" - -NUM_CALIBRATION_SAMPLES = 1 -MAX_SEQUENCE_LENGTH = 2048 - -# Load dataset and preprocess. -ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) -ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) - -print("Preprocessing samples") -def preprocess(example): - messages = [ - [ - { - "role": "user", - "content": [ - {"type": "image"}, - {"type": "text", "text": "What does the image show?"} - ] - } - ], - ] - return { - "text": processor.apply_chat_template( - messages, - add_generation_prompt=True, - ), - } - -ds = ds.map(preprocess) - - -# Tokenize inputs. -def tokenize(sample): - tmp = processor( - sample["image"], - sample["text"], - add_special_tokens=False, - return_tensors="pt" - ) - - # Remove batch dimension from each key - input_ids = tmp["input_ids"].squeeze(0) - attention_mask = tmp["attention_mask"].squeeze(0) - pixel_values = [tmp["pixel_values"][0][0].squeeze(0)] - aspect_ratio_ids = tmp["aspect_ratio_ids"].squeeze(0) - aspect_ratio_mask = tmp["aspect_ratio_mask"].squeeze(0) - cross_attention_mask = tmp["cross_attention_mask"].squeeze(0) - - return { - "input_ids": torch.LongTensor(input_ids), - "attention_mask": attention_mask, - "pixel_values": pixel_values, - "aspect_ratio_ids": aspect_ratio_ids, - "aspect_ratio_mask": aspect_ratio_mask, - "cross_attention_mask": cross_attention_mask, - } - - - -ds = ds.map(tokenize, remove_columns=ds.column_names) - -print("Setting up quantization params") -# Configure the quantization algorithm and scheme. -# In this case, we: -# * quantize the weights to fp8 with per channel via ptq -# * quantize the activations to fp8 with dynamic per token -#ignore=["re:.*lm_head", "re:model.vision_embed_tokens.*"] -#ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*", "re:language_model.*cross_attn.*"], -ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"] - -recipe = [ - # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), - GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore, update_size=NUM_CALIBRATION_SAMPLES), -] - -save_name = model_id.split("/")[1] + "-W8A8" -save_path = os.path.join("./my_test/", save_name) -print("Starting quantization") -oneshot( - model=model, - tokenizer=model_id, - dataset=ds, - recipe=recipe, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, - trust_remote_code_model=True, - output_dir=save_path, -) - -#processor.save_pretrained(save_path) - -# Confirm generations of the quantized model look sane. -print("========== SAMPLE GENERATION ==============") -input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") -output = model.generate(input_ids, max_new_tokens=20) -print(processor.decode(output[0])) -print("==========================================") diff --git a/shubhra.py b/shubhra.py deleted file mode 100644 index 1a8dc088f..000000000 --- a/shubhra.py +++ /dev/null @@ -1,114 +0,0 @@ -import torch -from datasets import load_dataset -from transformers import AutoProcessor, MllamaForConditionalGeneration, LlavaForConditionalGeneration, AutoModelForCausalLM, Qwen2VLForConditionalGeneration - -from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.transformers import oneshot -import os - -# Load model. -model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" -#model = MllamaForConditionalGeneration.from_pretrained(model_id) -#model_id = "mgoin/pixtral-12b" -#model_id = "Qwen/Qwen2-VL-2B-Instruct" - -#model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="cuda:0", _attn_implementation="eager") -#model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") -#model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0", torch_dtype="auto") -model = MllamaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") -processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) - -print("Loading dataset") -DATASET_ID = "lmms-lab/flickr30k" -DATASET_SPLIT = "test[:512]" - -NUM_CALIBRATION_SAMPLES = 1 -MAX_SEQUENCE_LENGTH = 2048 - -# Load dataset and preprocess. -ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) -ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) - -print("Preprocessing samples") -def preprocess(example): - messages = [ - [ - { - "role": "user", - "content": [ - {"type": "image"}, - {"type": "text", "text": "What does the image show?"} - ] - } - ], - ] - return { - "text": processor.apply_chat_template( - messages, - add_generation_prompt=True, - ), - } - -ds = ds.map(preprocess) - - -# Tokenize inputs. -def tokenize(sample): - tmp = processor( - sample["image"], - sample["text"], - add_special_tokens=False, - return_tensors="pt" - ) - - # Remove batch dimension from each key - input_ids = tmp["input_ids"].squeeze(0) - attention_mask = tmp["attention_mask"].squeeze(0) - #pixel_values = [tmp["pixel_values"][0][0].squeeze(0)] - - return { - "input_ids": torch.LongTensor(input_ids), - "attention_mask": attention_mask, - #"pixel_values": pixel_values, - } - - - -ds = ds.map(tokenize, remove_columns=ds.column_names) - -print("Setting up quantization params") -# Configure the quantization algorithm and scheme. -# In this case, we: -# * quantize the weights to fp8 with per channel via ptq -# * quantize the activations to fp8 with dynamic per token -#ignore=["re:.*lm_head", "re:model.vision_embed_tokens.*"] -#ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*", "re:language_model.*cross_attn.*"], -ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"] - -recipe = [ - # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), - GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore, update_size=NUM_CALIBRATION_SAMPLES), -] - -save_name = model_id.split("/")[1] + "-W8A8" -save_path = os.path.join("./my_test/", save_name) -print("Starting quantization") -oneshot( - model=model, - tokenizer=model_id, - dataset=ds, - recipe=recipe, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, - trust_remote_code_model=True, - output_dir=save_path, -) - -#processor.save_pretrained(save_path) - -# Confirm generations of the quantized model look sane. -print("========== SAMPLE GENERATION ==============") -input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") -output = model.generate(input_ids, max_new_tokens=20) -print(processor.decode(output[0])) -print("==========================================") From a62617c64d42697be3a5a76e06ab8de947d5d860 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 28 Nov 2024 16:02:37 +0000 Subject: [PATCH 110/285] clean up CustomDataset Signed-off-by: Kyle Sayers --- .../transformers/finetune/data/custom.py | 75 ++++++++++++------- 1 file changed, 46 insertions(+), 29 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/custom.py b/src/llmcompressor/transformers/finetune/data/custom.py index e849594e7..7b74f30ac 100644 --- a/src/llmcompressor/transformers/finetune/data/custom.py +++ b/src/llmcompressor/transformers/finetune/data/custom.py @@ -44,56 +44,73 @@ def __init__(self, data_args, split, tokenizer): split=split, tokenizer=tokenizer, ) - self.preprocessing_func = data_args.preprocessing_func - self.remove_columns = data_args.remove_columns def get_raw_dataset(self, *_ignore, **__ignore) -> Union[DatasetDict, Dataset]: """Get the raw dataset and apply preprocessing func if provided""" - dataset = self.data_args.dataset - if isinstance(dataset, DatasetDict) or isinstance(dataset, Dataset): - # user passed in an already instantiated dataset, just use it directly - raw_dataset = dataset - else: - # dataset must be loaded from file or HF Hub - raw_dataset = super().get_raw_dataset() - - if self.preprocessing_func is not None: - if callable(self.preprocessing_func): - func = self.preprocessing_func - elif ":" in self.preprocessing_func: + # load dataset + dataset = ( + self.data_args.dataset + if isinstance(self.data_args.dataset, (DatasetDict, Dataset)) + else super().get_raw_dataset() # load dataset from file or HF Hub + ) + + # preprocess dataset + dataset = self._preprocess_dataset(dataset) + dataset = self._remove_columns_from_dataset(dataset) + + return dataset + + def _preprocess_dataset( + self, dataset: Union[DatasetDict, Dataset] + ) -> Union[DatasetDict, Dataset]: + preprocessing_func = self.data_args.preprocessing_func + + if preprocessing_func is not None: + if callable(preprocessing_func): + pass + + elif ":" in preprocessing_func: # load func_name from "/path/to/file.py:func_name" - func = import_from_path(self.preprocessing_func) + preprocessing_func = import_from_path(preprocessing_func) else: # load from the registry - func = PreprocessingFunctionRegistry.get_value_from_registry( - name=self.preprocessing_func + preprocessing_func = ( + PreprocessingFunctionRegistry.get_value_from_registry( + name=preprocessing_func + ) ) - raw_dataset = self.map( - raw_dataset, - function=func, + dataset = self.map( + dataset, + function=preprocessing_func, batched=False, num_proc=self.data_args.preprocessing_num_workers, desc="Applying custom func to the custom dataset", ) - self.remove_columns = ( - self.remove_columns or self.get_remove_columns_from_dataset(raw_dataset) - ) + return dataset + + def _remove_columns_from_dataset( + self, dataset: Union[DatasetDict, Dataset] + ) -> Union[DatasetDict, Dataset]: + remove_columns = self.data_args.remove_columns + + if not remove_columns: + remove_columns = self._get_remove_columns_from_dataset(dataset) - if self.remove_columns is not None: - raw_dataset = self.map( - raw_dataset, + if remove_columns is not None: + dataset = self.map( + dataset, batched=True, - remove_columns=self.remove_columns, + remove_columns=remove_columns, num_proc=self.data_args.preprocessing_num_workers, desc="Removing unneeded columns", ) - return raw_dataset + return dataset - def get_remove_columns_from_dataset( + def _get_remove_columns_from_dataset( self, raw_dataset: Union[DatasetDict, Dataset] ) -> List[str]: """Remove redandant columns from the dataset for processing""" From 57b5e025a2141d8201f86029b0dcbecabd216fbc Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 29 Nov 2024 06:06:40 +0000 Subject: [PATCH 111/285] chchchchanges Signed-off-by: Kyle Sayers --- .../transformers/finetune/data/base.py | 324 ++++++++++-------- .../transformers/finetune/data/c4.py | 10 +- .../finetune/data/cnn_dailymail.py | 38 +- .../transformers/finetune/data/custom.py | 100 +----- .../transformers/finetune/data/data_args.py | 15 +- .../finetune/data/evolcodealpaca.py | 43 +-- .../transformers/finetune/data/gsm8k.py | 40 +-- .../finetune/data/open_platypus.py | 55 +-- .../finetune/data/ultrachat_200k.py | 42 +-- .../transformers/finetune/data/wikitext.py | 4 +- .../transformers/finetune/model_args.py | 3 + .../transformers/finetune/runner.py | 40 +-- .../transformers/finetune/text_generation.py | 35 +- 13 files changed, 290 insertions(+), 459 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index d4c3a6222..fd692face 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -1,9 +1,10 @@ -from typing import Optional, Union +from functools import cached_property +from typing import Any, Callable, Union from compressed_tensors.registry import RegistryMixin -from datasets import Dataset, IterableDataset +from datasets import Dataset, DatasetDict, IterableDataset from loguru import logger -from transformers import AutoTokenizer +from transformers import AutoProcessor from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments from llmcompressor.transformers.finetune.data.data_helpers import ( @@ -11,6 +12,12 @@ get_custom_datasets_from_path, get_raw_dataset, ) +from llmcompressor.transformers.utils.preprocessing_functions import ( + PreprocessingFunctionRegistry, +) +from llmcompressor.utils import import_from_path + +DatasetType = Union[Dataset, DatasetDict, IterableDataset] class TextGenerationDataset(RegistryMixin): @@ -23,63 +30,119 @@ class TextGenerationDataset(RegistryMixin): :param tokenizer: tokenizer to use on dataset """ + # used to mask out the prompt so prompt tokens do not contribute to training loss PROMPT_KEY = "prompt" + # TODO: not sure how to handle the prompt stuff best. Specifically w.r.t. + """ + dataset = self.processor(**dataset) + + if dataset includes the PROMPT_KEY + """ + def __init__( self, - text_column: str, data_args: DataTrainingArguments, split: str, - tokenizer: AutoTokenizer, + processor: "AutoProcessor", ): - self.text_column = text_column - self.tokenizer = tokenizer self.data_args = data_args - self.raw_kwargs = data_args.raw_kwargs or {} self.split = split - self.dvc_dataset = ( - True if self.data_args.dvc_data_repository is not None else False - ) - self.custom_dataset = True if self.data_args.dataset_path is not None else False + self.processor = processor - # configure padding - if data_args.concatenate_data: - self.padding = False - elif data_args.pad_to_max_length: - self.padding = "max_length" - else: - self.padding = False + # get tokenizer + self.tokenizer = getattr(self.processor, "tokenizer", self.processor) - if self.tokenizer: - if not self.tokenizer.pad_token: - self.tokenizer.pad_token = self.tokenizer.eos_token + # fill in pad token + if self.tokenizer.pad_token: + self.tokenizer.pad_token = self.tokenizer.eos_token # configure sequence length max_seq_length = data_args.max_seq_length - model_max_length = tokenizer.model_max_length if tokenizer else max_seq_length - if self.tokenizer and max_seq_length > model_max_length: + if data_args.max_seq_length > self.tokenizer.model_max_length: logger.warning( f"The max_seq_length passed ({max_seq_length}) is larger than " - f"the maximum length for the model ({tokenizer.model_max_length}). " - f"Using max_seq_length={tokenizer.model_max_length}." + f"the maximum length for model ({self.tokenizer.model_max_length}). " + f"Using max_seq_length={self.tokenizer.model_max_length}." + ) + self.max_seq_length = min( + data_args.max_seq_length, self.tokenizer.model_max_length + ) + + def __call__(self, add_labels: bool) -> DatasetType: + dataset = self.data_args.dataset + + # load dataset + if isinstance(dataset, str): + dataset = self.load_dataset() + + # preprocess + if self.preprocess is not None: + dataset = self.map( + dataset, + self.preprocess, + batched=False, + remove_columns=dataset.column_names, + num_proc=self.data_args.preprocessing_num_workers, + desc="Preprocessing", + ) + + # tokenize + if "input_ids" not in dataset.column_names: + dataset = self.map( + dataset, + self.tokenize, + batched=True, + remove_columns=dataset.column_names, + num_proc=self.data_args.preprocessing_num_workers, + load_from_cache_file=not self.data_args.overwrite_cache, + desc="Tokenizing", ) - self.max_seq_length = min(data_args.max_seq_length, model_max_length) - def get_raw_dataset(self, cache_dir: Optional[str] = None) -> Dataset: + # postprocess + + if self.data_args.concatenate_data: + # group text + dataset = self.map( + dataset, + function=self.group_text, + batched=True, + num_proc=self.data_args.preprocessing_num_workers, + load_from_cache_file=not self.data_args.overwrite_cache, + desc="Concatenating data", + ) + + if add_labels: + # add labels + dataset = self.map( + dataset, + function=self.add_labels, + batched=False, # not compatible with batching, need row lengths + num_proc=self.data_args.preprocessing_num_workers, + load_from_cache_file=not self.data_args.overwrite_cache, + desc="Adding labels", + ) + + if self.PROMPT_KEY in dataset: + del dataset[self.PROMPT_KEY] + + return dataset + + def load_dataset(self): """ Load the raw dataset from Hugging Face, using cached copy if available :param cache_dir: disk location to search for cached dataset :return: the requested dataset """ - if self.custom_dataset: - if self.dvc_dataset: - self.raw_kwargs["storage_options"] = { + if self.data_args.dataset_path is not None: + if self.data_args.dvc_data_repository is not None: + self.data_args.raw_kwargs["storage_options"] = { "url": self.data_args.dvc_data_repository } - self.raw_kwargs["data_files"] = self.data_args.dataset_path + self.data_args.raw_kwargs["data_files"] = self.data_args.dataset_path else: - self.raw_kwargs["data_files"] = get_custom_datasets_from_path( + self.data_args.raw_kwargs["data_files"] = get_custom_datasets_from_path( self.data_args.dataset_path, self.data_args.dataset if hasattr(self.data_args, "dataset") @@ -88,129 +151,98 @@ def get_raw_dataset(self, cache_dir: Optional[str] = None) -> Dataset: return get_raw_dataset( self.data_args, - cache_dir, + None, split=self.split, streaming=self.data_args.streaming, - **self.raw_kwargs, + **self.data_args.raw_kwargs, ) - def tokenize_and_process( - self, raw_dataset: Optional[Dataset] = None, add_labels: Optional[bool] = True - ) -> Dataset: - """ - Sets up the raw dataset for finetuning, performs tokenization, concatenates - entries to max sequence length if desired, and adds labels to each entry + @cached_property + def preprocess(self) -> Union[Callable[[Any], Any], None]: + preprocessing_func = self.data_args.preprocessing_func - :param raw_dataset: dataset to process - :param add_labels: whether to include labels in tokenized output - """ + if callable(preprocessing_func): + return preprocessing_func - # helper fn for tokenizing text column - def tokenize_fn(data): - result = self.tokenizer( - data[self.text_column], - padding=self.padding, - max_length=self.max_seq_length, - truncation=True, - ) + if isinstance(preprocessing_func, str): + if ":" in preprocessing_func: + # load func_name from "/path/to/file.py:func_name" + return import_from_path(preprocessing_func) + else: + # load from the registry + return PreprocessingFunctionRegistry.get_value_from_registry( + name=preprocessing_func + ) - # store unpadded prompt so we can mask out correct number of elements - # in the labels - if self.PROMPT_KEY in data: - result[self.PROMPT_KEY] = self.tokenizer( - data[self.PROMPT_KEY], - max_length=self.max_seq_length, - truncation=True, - )["input_ids"] - - return result - - # helper fn for filling to max_sequence_length by concatenating entries - def group_text_fn(data): - concatenated_data = {k: sum(data[k], []) for k in data.keys()} - total_length = len(concatenated_data[list(data.keys())[0]]) - total_length = (total_length // self.max_seq_length) * self.max_seq_length - result = { - k: [ - t[i : i + self.max_seq_length] - for i in range(0, total_length, self.max_seq_length) - ] - for k, t in concatenated_data.items() - } - return result - - # helper fn for adding labels, needed for loss calculation - def label_fn(data): - # if the dataset uses prompts, mask them out so they don't contribute - # to the loss calculation - prompt_len = 0 - if self.PROMPT_KEY in data: - prompt_len = len(data[self.PROMPT_KEY]) - data["labels"] = data["input_ids"].copy() - data["labels"][:prompt_len] = [LABELS_MASK_VALUE] * prompt_len - - # mask out padding in the labels as well - padding = len(data["attention_mask"]) - sum(data["attention_mask"]) - if padding > 0: - data["labels"][-padding:] = [LABELS_MASK_VALUE] * padding - return data - - if raw_dataset is None: - raw_dataset = self.get_raw_dataset() - - dataset = self.map( - raw_dataset, - function=tokenize_fn, - batched=True, - remove_columns=[self.text_column], - num_proc=self.data_args.preprocessing_num_workers, - load_from_cache_file=not self.data_args.overwrite_cache, - desc="Running tokenizer on dataset", - ) + return self.dataset_template - if self.data_args.concatenate_data: - dataset = self.map( - dataset, - function=group_text_fn, - batched=True, - num_proc=self.data_args.preprocessing_num_workers, - load_from_cache_file=not self.data_args.overwrite_cache, - desc="Grouping text", - ) + @property + def dataset_template(self) -> Union[Callable[[Any], Any], None]: + return None - if isinstance(dataset, IterableDataset): - # so we can get column names from streamed_datasets - dataset = dataset._resolve_features() + def tokenize(self, dataset: DatasetType) -> DatasetType: + # manually swap text argument if specified + if self.data_args.text_column not in dataset.column_names: + dataset["text"] = dataset[self.data_args.text_column] - column_names = dataset.column_names - if isinstance(column_names, dict): - column_names = column_names[list(column_names)[0]] + # tokenize + dataset = self.processor( + **dataset, + padding=( + False + if self.data_args.concatenate_data + else "max_length" + if self.data_args.pad_to_max_length + else False + ), + max_length=self.max_seq_length, + truncation=True, + ) - if add_labels: - dataset = self.map( - dataset, - function=label_fn, - batched=False, # not compatible with batching, need row lengths - remove_columns=[self.PROMPT_KEY] - if self.PROMPT_KEY in column_names - else None, - num_proc=self.data_args.preprocessing_num_workers, - load_from_cache_file=not self.data_args.overwrite_cache, - desc="Adding labels", - ) - else: - dataset = self.map( - dataset, - batched=False, # not compatible with batching, need row lengths - remove_columns=[self.PROMPT_KEY] - if self.PROMPT_KEY in column_names - else None, - ) + # store unpadded prompt so we can mask out correct number of elements + # in the labels + if self.PROMPT_KEY in dataset: + dataset[self.PROMPT_KEY] = self.processor( + dataset[self.PROMPT_KEY], + max_length=self.max_seq_length, + truncation=True, + )["input_ids"] return dataset + def group_text(self, data): + concatenated_data = {k: sum(data[k], []) for k in data.keys()} + total_length = len(concatenated_data[list(data.keys())[0]]) + total_length = (total_length // self.max_seq_length) * self.max_seq_length + result = { + k: [ + t[i : i + self.max_seq_length] + for i in range(0, total_length, self.max_seq_length) + ] + for k, t in concatenated_data.items() + } + return result + + def add_labels(self, data): + # if the dataset uses prompts, mask them out so they don't contribute + # to the loss calculation + prompt_len = 0 + if self.PROMPT_KEY in data: + prompt_len = len(data[self.PROMPT_KEY]) + data["labels"] = data["input_ids"].copy() + data["labels"][:prompt_len] = [LABELS_MASK_VALUE] * prompt_len + + # mask out padding in the labels as well + padding = len(data["attention_mask"]) - sum(data["attention_mask"]) + if padding > 0: + data["labels"][-padding:] = [LABELS_MASK_VALUE] * padding + return data + def map( - self, dataset: Union[Dataset, IterableDataset], **kwargs + self, + dataset: Union[Dataset, IterableDataset], + map_fn: Union[Callable[[Any], Any], None], + **kwargs, ) -> Union[Dataset, IterableDataset]: """ Wrapper function around Dataset.map and IterableDataset.map, clears invalid @@ -220,10 +252,18 @@ def map( :param kwargs: args to pass on to map function :return: mapped dataset """ + if map_fn is None: + return dataset + if isinstance(dataset, IterableDataset): # remove arguments that don't apply to streaming kwargs.pop("num_proc", None) kwargs.pop("load_from_cache_file", None) kwargs.pop("desc", None) - return dataset.map(**kwargs) + dataset = dataset.map(map_fn, **kwargs) + + if isinstance(dataset, IterableDataset): + dataset = dataset._resolve_features() + + return dataset diff --git a/src/llmcompressor/transformers/finetune/data/c4.py b/src/llmcompressor/transformers/finetune/data/c4.py index 37eeceae6..e7f82814e 100644 --- a/src/llmcompressor/transformers/finetune/data/c4.py +++ b/src/llmcompressor/transformers/finetune/data/c4.py @@ -10,12 +10,12 @@ class C4Dataset(TextGenerationDataset): :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` - :param tokenizer: tokenizer to use on dataset + :param processor: processor to use on dataset """ - def __init__(self, data_args, split, tokenizer): + def __init__(self, data_args, split, processor): data_args = deepcopy(data_args) data_args.dataset = "allenai/c4" - super().__init__( - text_column="text", data_args=data_args, split=split, tokenizer=tokenizer - ) + data_args.text_column = "text" + + super().__init__(data_args=data_args, split=split, processor=processor) diff --git a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py index 64755de4a..02c266656 100644 --- a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py +++ b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy -from typing import Optional from llmcompressor.transformers.finetune.data import TextGenerationDataset @@ -24,44 +23,21 @@ class CNNDailyMailDataset(TextGenerationDataset): :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` - :param tokenizer: tokenizer to use on dataset + :param processor: processor to use on dataset """ SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n" - def __init__(self, data_args, split, tokenizer): + def __init__(self, data_args, split, processor): data_args = deepcopy(data_args) data_args.dataset = "cnn_dailymail" data_args.dataset_config_name = "3.0.0" - super().__init__( - text_column="text", data_args=data_args, split=split, tokenizer=tokenizer - ) + super().__init__(data_args=data_args, split=split, processor=processor) - def get_raw_dataset(self, cache_dir: Optional[str] = None): - """ - Load the raw dataset from Hugging Face, using cached copy if available. - Additionally reformats the entries to fit the template. - - :param cache_dir: disk location to search for cached dataset - :return: the requested dataset - """ - raw_dataset = super().get_raw_dataset(cache_dir=cache_dir) - - def restructure_fn(sample): - sample["text"] = self.SAMPLE_TEMPLATE.format( + def dataset_template(self, sample): + return { + "text": self.SAMPLE_TEMPLATE.format( article=sample["article"], highlights=sample["highlights"] ) - - return sample - - raw_dataset = self.map( - raw_dataset, - function=restructure_fn, - batched=False, - remove_columns=["article", "highlights", "id"], - num_proc=self.data_args.preprocessing_num_workers, - load_from_cache_file=not self.data_args.overwrite_cache, - desc="Restructuring CNN/DailyMail Dataset", - ) - return raw_dataset + } diff --git a/src/llmcompressor/transformers/finetune/data/custom.py b/src/llmcompressor/transformers/finetune/data/custom.py index 7b74f30ac..175d13468 100644 --- a/src/llmcompressor/transformers/finetune/data/custom.py +++ b/src/llmcompressor/transformers/finetune/data/custom.py @@ -11,16 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy -from typing import Dict, List, Union - -from datasets.dataset_dict import Dataset, DatasetDict - from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.transformers.utils.preprocessing_functions import ( - PreprocessingFunctionRegistry, -) -from llmcompressor.utils import import_from_path @TextGenerationDataset.register(name="custom", alias=["json", "csv"]) @@ -36,93 +27,4 @@ class CustomDataset(TextGenerationDataset): """ - def __init__(self, data_args, split, tokenizer): - data_args = deepcopy(data_args) - super().__init__( - text_column=data_args.text_column, - data_args=data_args, - split=split, - tokenizer=tokenizer, - ) - - def get_raw_dataset(self, *_ignore, **__ignore) -> Union[DatasetDict, Dataset]: - """Get the raw dataset and apply preprocessing func if provided""" - - # load dataset - dataset = ( - self.data_args.dataset - if isinstance(self.data_args.dataset, (DatasetDict, Dataset)) - else super().get_raw_dataset() # load dataset from file or HF Hub - ) - - # preprocess dataset - dataset = self._preprocess_dataset(dataset) - dataset = self._remove_columns_from_dataset(dataset) - - return dataset - - def _preprocess_dataset( - self, dataset: Union[DatasetDict, Dataset] - ) -> Union[DatasetDict, Dataset]: - preprocessing_func = self.data_args.preprocessing_func - - if preprocessing_func is not None: - if callable(preprocessing_func): - pass - - elif ":" in preprocessing_func: - # load func_name from "/path/to/file.py:func_name" - preprocessing_func = import_from_path(preprocessing_func) - else: - # load from the registry - preprocessing_func = ( - PreprocessingFunctionRegistry.get_value_from_registry( - name=preprocessing_func - ) - ) - - dataset = self.map( - dataset, - function=preprocessing_func, - batched=False, - num_proc=self.data_args.preprocessing_num_workers, - desc="Applying custom func to the custom dataset", - ) - - return dataset - - def _remove_columns_from_dataset( - self, dataset: Union[DatasetDict, Dataset] - ) -> Union[DatasetDict, Dataset]: - remove_columns = self.data_args.remove_columns - - if not remove_columns: - remove_columns = self._get_remove_columns_from_dataset(dataset) - - if remove_columns is not None: - dataset = self.map( - dataset, - batched=True, - remove_columns=remove_columns, - num_proc=self.data_args.preprocessing_num_workers, - desc="Removing unneeded columns", - ) - - return dataset - - def _get_remove_columns_from_dataset( - self, raw_dataset: Union[DatasetDict, Dataset] - ) -> List[str]: - """Remove redandant columns from the dataset for processing""" - - remove_columns = raw_dataset.column_names - if isinstance(remove_columns, Dict): - remove_columns = raw_dataset[list(raw_dataset.keys())[0]].column_names - - remove_columns = set(remove_columns) - if self.text_column in remove_columns: - remove_columns.remove(self.text_column) - if self.PROMPT_KEY in remove_columns: - remove_columns.remove(self.PROMPT_KEY) - - return list(remove_columns) + pass diff --git a/src/llmcompressor/transformers/finetune/data/data_args.py b/src/llmcompressor/transformers/finetune/data/data_args.py index d79ef55fa..9623d413a 100644 --- a/src/llmcompressor/transformers/finetune/data/data_args.py +++ b/src/llmcompressor/transformers/finetune/data/data_args.py @@ -31,22 +31,25 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments): }, ) - text_column: Optional[str] = field( + text_column: str = field( default="text", metadata={"help": "For custom datasets only. The text field key"}, ) remove_columns: Union[None, str, List] = field( default=None, - metadata={"help": "Column names to remove after preprocessing custom datasets"}, + metadata={ + "help": "For custom datasets only. Column names to remove after " + "preprocessing custom datasets" + }, ) preprocessing_func: Union[None, str, Callable] = field( default=None, metadata={ "help": ( - "The preprocessing function to apply or the preprocessing func name in " - "src/llmcompressor/transformers/utils/preprocessing_functions.py" + "For custom datasets only. Either a function to apply to the dataset, " + "a path to a function definition of the form /path/to/file.py:func" ) }, ) @@ -91,8 +94,8 @@ class DataTrainingArguments(CustomDataTrainingArguments): "help": "Whether or not to concatenate datapoints to fill max_seq_length" }, ) - raw_kwargs: Optional[Dict] = field( - default=None, + raw_kwargs: Dict = field( + default_factory=dict, metadata={"help": "Additional keyboard args to pass to datasets load_data"}, ) splits: Union[None, str, List, Dict] = field( diff --git a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py index 9529d3115..4e96409e8 100644 --- a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py +++ b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy -from typing import Optional from llmcompressor.transformers.finetune.data import TextGenerationDataset @@ -37,37 +36,17 @@ class EvolCodeAlpacaDataset(TextGenerationDataset): def __init__(self, data_args, split, tokenizer): data_args = deepcopy(data_args) data_args.dataset = "theblackcat102/evol-codealpaca-v1" - super().__init__( - text_column="text", data_args=data_args, split=split, tokenizer=tokenizer - ) + data_args.text_column = "text" - def get_raw_dataset(self, cache_dir: Optional[str] = None): - """ - Load the raw dataset from Hugging Face, using cached copy if available. - Additionally reformats the entries to fit the alpaca template. + super().__init__(data_args, split=split, tokenizer=tokenizer) - :param cache_dir: disk location to search for cached dataset - :return: the requested dataset - """ - raw_dataset = super().get_raw_dataset(cache_dir=cache_dir) + def dataset_template(self, sample): + prompt = self.EVOL_ALPACA_TEMPLATE.format(instruction=sample["instruction"]) + text = prompt + if "output" in text: + text += sample["output"] - # helper fn for restructuring each dataset entry using the alpaca template - def restructure_fn(sample): - sample["text"] = self.EVOL_ALPACA_TEMPLATE.format( - instruction=sample["instruction"] - ) - sample[self.PROMPT_KEY] = sample["text"] - if "output" in sample: - sample["text"] += sample["output"] - return sample - - raw_dataset = self.map( - raw_dataset, - function=restructure_fn, - batched=False, - remove_columns=["output", "instruction"], - num_proc=self.data_args.preprocessing_num_workers, - load_from_cache_file=not self.data_args.overwrite_cache, - desc="Restructuring Evol Code Alpaca Dataset", - ) - return raw_dataset + return { + "text": text, + self.PROMPT_KEY: prompt, + } diff --git a/src/llmcompressor/transformers/finetune/data/gsm8k.py b/src/llmcompressor/transformers/finetune/data/gsm8k.py index f9a94bcf4..a4e8f41cd 100644 --- a/src/llmcompressor/transformers/finetune/data/gsm8k.py +++ b/src/llmcompressor/transformers/finetune/data/gsm8k.py @@ -1,5 +1,4 @@ from copy import deepcopy -from typing import Optional from llmcompressor.transformers.finetune.data import TextGenerationDataset @@ -19,35 +18,18 @@ class GSM8KDataset(TextGenerationDataset): def __init__(self, data_args, split, tokenizer): data_args = deepcopy(data_args) data_args.dataset = "gsm8k" + super().__init__( text_column="text", data_args=data_args, split=split, tokenizer=tokenizer ) - def get_raw_dataset(self, cache_dir: Optional[str] = None): - """ - Load the raw dataset from Hugging Face, using cached copy if available. - Additionally reformats the entries to fit the alpaca template. - - :param cache_dir: disk location to search for cached dataset - :return: the requested dataset - """ - raw_dataset = super().get_raw_dataset(cache_dir=cache_dir) - - # helper fn for restructuring each dataset entry using the gsm template - def restructure_fn(sample): - sample["text"] = self.GSM_TEMPLATE.format(question=sample["question"]) - sample[self.PROMPT_KEY] = sample["text"] - if "answer" in sample: - sample["text"] += " " + sample["answer"] - return sample - - raw_dataset = self.map( - raw_dataset, - function=restructure_fn, - batched=False, - remove_columns=["question", "answer"], - num_proc=self.data_args.preprocessing_num_workers, - load_from_cache_file=not self.data_args.overwrite_cache, - desc="Restructuring GSM Dataset", - ) - return raw_dataset + def dataset_template(self, sample): + prompt = self.GSM_TEMPLATE.format(question=sample["question"]) + text = prompt + if "answer" in sample: + text += " " + sample["answer"] + + return { + "text": text, + self.PROMPT_KEY: prompt, + } diff --git a/src/llmcompressor/transformers/finetune/data/open_platypus.py b/src/llmcompressor/transformers/finetune/data/open_platypus.py index 55e54cbce..4b85331db 100644 --- a/src/llmcompressor/transformers/finetune/data/open_platypus.py +++ b/src/llmcompressor/transformers/finetune/data/open_platypus.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy -from typing import Optional from llmcompressor.transformers.finetune.data import TextGenerationDataset @@ -44,39 +43,21 @@ def __init__(self, data_args, split, tokenizer): text_column="text", data_args=data_args, split=split, tokenizer=tokenizer ) - def get_raw_dataset(self, cache_dir: Optional[str] = None): - """ - Load the raw dataset from Hugging Face, using cached copy if available. - Additionally reformats the entries to fit the alpaca template. - - :param cache_dir: disk location to search for cached dataset - :return: the requested dataset - """ - raw_dataset = super().get_raw_dataset(cache_dir=cache_dir) - - # helper fn for restructuring each dataset entry using the alpaca template - def restructure_fn(sample): - if "input" in sample and sample["input"] != "": - sample["text"] = self.ALPACA_TEMPLATE["prompt_input"].format( - instruction=sample["instruction"], input=sample["input"] - ) - else: - sample["text"] = self.ALPACA_TEMPLATE["prompt_no_input"].format( - instruction=sample["instruction"] - ) - - sample[self.PROMPT_KEY] = sample["text"] - if "output" in sample: - sample["text"] += sample["output"] - return sample - - raw_dataset = self.map( - raw_dataset, - function=restructure_fn, - batched=False, - remove_columns=["input", "output", "instruction", "data_source"], - num_proc=self.data_args.preprocessing_num_workers, - load_from_cache_file=not self.data_args.overwrite_cache, - desc="Restructuring Platypus Dataset", - ) - return raw_dataset + def dataset_template(self, sample): + if "input" in sample and sample["input"] != "": + prompt = self.ALPACA_TEMPLATE["prompt_input"].format( + instruction=sample["instruction"], input=sample["input"] + ) + else: + prompt = self.ALPACA_TEMPLATE["prompt_no_input"].format( + instruction=sample["instruction"] + ) + + text = prompt + if "output" in sample: + text += sample["output"] + + return { + "text": text, + self.PROMPT_KEY: prompt, + } diff --git a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py index 5b2e66ab5..0b71d6893 100644 --- a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py +++ b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy -from typing import Optional from llmcompressor.transformers.finetune.data import TextGenerationDataset @@ -43,50 +42,27 @@ class UltraChatDataset(TextGenerationDataset): def __init__(self, data_args, split, tokenizer): data_args = deepcopy(data_args) data_args.dataset = "HuggingFaceH4/ultrachat_200k" + data_args.text_column = "messages" if split in ["train", "test"]: split += "_sft" super().__init__( - text_column="messages", data_args=data_args, split=split, tokenizer=tokenizer, ) - if ( - not hasattr(self.tokenizer, "chat_template") - or self.tokenizer.chat_template is None - ): + if getattr(self.tokenizer, "chat_template", None) is None: self.tokenizer.chat_template = self.DEFAULT_CHAT_TEMPLATE - def get_raw_dataset(self, cache_dir: Optional[str] = None): - """ - Load the raw dataset from Hugging Face, using cached copy if available. - Additionally reformats the entries to fit the alpaca template. + def dataset_template(self, sample): + messages = sample["messages"] - :param cache_dir: disk location to search for cached dataset - :return: the requested dataset - """ - raw_dataset = super().get_raw_dataset(cache_dir=cache_dir) + if messages[0]["role"] != "system": + messages.insert(0, {"role": "system", "content": ""}) - # helper fn for restructuring each dataset entry using the chat template - def restructure_fn(sample): - if sample["messages"][0]["role"] != "system": - sample["messages"].insert(0, {"role": "system", "content": ""}) - - sample["messages"] = self.tokenizer.apply_chat_template( - sample["messages"], tokenize=False, add_generation_prompt=False - ) - return sample - - raw_dataset = self.map( - raw_dataset, - function=restructure_fn, - batched=False, - remove_columns=[], - num_proc=self.data_args.preprocessing_num_workers, - load_from_cache_file=not self.data_args.overwrite_cache, - desc="Restructuring Ultra Chat Dataset", + text = self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=False ) - return raw_dataset + return {"text": text} diff --git a/src/llmcompressor/transformers/finetune/data/wikitext.py b/src/llmcompressor/transformers/finetune/data/wikitext.py index 034d58ba2..01a9c1ed0 100644 --- a/src/llmcompressor/transformers/finetune/data/wikitext.py +++ b/src/llmcompressor/transformers/finetune/data/wikitext.py @@ -12,6 +12,4 @@ class WikiTextDataset(TextGenerationDataset): """ def __init__(self, data_args, split, tokenizer): - super().__init__( - text_column="text", data_args=data_args, split=split, tokenizer=tokenizer - ) + super().__init__(data_args=data_args, split=split, tokenizer=tokenizer) diff --git a/src/llmcompressor/transformers/finetune/model_args.py b/src/llmcompressor/transformers/finetune/model_args.py index d3d8e974f..606d440cb 100644 --- a/src/llmcompressor/transformers/finetune/model_args.py +++ b/src/llmcompressor/transformers/finetune/model_args.py @@ -34,6 +34,9 @@ class ModelArguments: "help": "Pretrained tokenizer name or path if not the same as model_name" }, ) + processor: Optional[str] = field( + default=None, + ) cache_dir: Optional[str] = field( default=None, metadata={"help": "Where to store the pretrained data from huggingface.co"}, diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py index 6344b1a2b..530c5198e 100644 --- a/src/llmcompressor/transformers/finetune/runner.py +++ b/src/llmcompressor/transformers/finetune/runner.py @@ -1,12 +1,12 @@ import math import os import re -from typing import List, Optional +from typing import List, Optional, Union import torch from loguru import logger from torch.utils.data import Dataset -from transformers import AutoTokenizer +from transformers import AutoProcessor, PreTrainedTokenizerBase from llmcompressor.core import active_session from llmcompressor.pytorch.model_load.helpers import ( @@ -56,11 +56,14 @@ def __init__( self.datasets = {} self.trainer = None - self.tokenizer = None self.parent_output_dir = self._training_args.output_dir self._output_dir = self._training_args.output_dir - def populate_datasets(self, tokenizer: "AutoTokenizer", add_labels: bool = True): + def populate_datasets( + self, + processor: Union["AutoProcessor", PreTrainedTokenizerBase], + add_labels: bool = True, + ): """ Loads datasets for each flow based on data_args, stores a Dataset for each enabled flow in self.datasets @@ -68,7 +71,7 @@ def populate_datasets(self, tokenizer: "AutoTokenizer", add_labels: bool = True) :param tokenizer: tokenizer to use for dataset tokenization """ if self._data_args.dataset is None: - self.tokenizer = self._model_args.tokenizer + self.tokenizer = self._model_args.processor logger.info( "Running oneshot without calibration data. This is expected for " "weight-only and dynamic quantization" @@ -93,29 +96,19 @@ def _get_split_name(inp_str): splits = {_get_split_name(s): s for s in splits} # default to custom dataset if dataset provided isn't a string - registry_id = self._data_args.dataset - - if not isinstance(registry_id, str): - registry_id = "custom" + registry_id = ( + self._data_args.dataset + if isinstance(self._data_args.dataset, str) + else "custom" + ) for split_name, split_str in splits.items(): dataset_manager = TextGenerationDataset.load_from_registry( - registry_id, + name=registry_id, data_args=self._data_args, split=split_str, - tokenizer=tokenizer, + processor=processor, ) - - dataset = self._data_args.dataset - if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names: - # dataset is already tokenized - tokenized_datasets[split_name] = dataset - else: - # dataset needs to be tokenized - raw_dataset = dataset_manager.get_raw_dataset() - tokenized_dataset = dataset_manager.tokenize_and_process( - raw_dataset, add_labels=add_labels - ) - tokenized_datasets[split_name] = tokenized_dataset + tokenized_datasets[split_name] = dataset_manager(add_labels=add_labels) self.datasets = make_dataset_splits( tokenized_datasets, @@ -124,7 +117,6 @@ def _get_split_name(inp_str): do_predict=self._training_args.do_predict, do_oneshot=self._training_args.do_oneshot, ) - self.tokenizer = tokenizer def get_dataset_split(self, split_name: str) -> Dataset: """ diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index 85aa6d82c..fdae1c0f1 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -24,7 +24,7 @@ from transformers import ( AutoConfig, AutoModelForCausalLM, - AutoTokenizer, + AutoProcessor, DefaultDataCollator, HfArgumentParser, set_seed, @@ -226,11 +226,11 @@ def initialize_model_from_path( return teacher, model_path, model -def initialize_tokenizer_from_path(model_args, model, teacher): - tokenizer_src = model_args.tokenizer - tokenizer_src = tokenizer_src or get_shared_tokenizer_src(model, teacher) - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_src, +def initialize_processor_from_path(model_args, model, teacher): + processor_src = model_args.processor + processor_src = processor_src or get_shared_tokenizer_src(model, teacher) + processor = AutoProcessor.from_pretrained( + processor_src, cache_dir=model_args.cache_dir, use_fast=True, revision=model_args.model_revision, @@ -238,7 +238,7 @@ def initialize_tokenizer_from_path(model_args, model, teacher): trust_remote_code=model_args.trust_remote_code_model, ) - return tokenizer + return processor def main( @@ -299,11 +299,9 @@ def main( # Detecting last checkpoint. last_checkpoint = None teacher = model_args.distill_teacher - model = model_args.model - # Load tokenizer - # distill TODO: support for different tokenizer for teacher? - tokenizer = model_args.tokenizer + # distill TODO: support for different processor for teacher? + model = model_args.model if isinstance(model, str) or isinstance(model, PosixPath): (teacher, _model_path, model) = initialize_model_from_path( model_args, @@ -317,8 +315,9 @@ def main( if teacher is not None: teacher.eval() - if isinstance(tokenizer, str) or tokenizer is None: - tokenizer = initialize_tokenizer_from_path(model_args, model, teacher) + processor = model_args.processor + if isinstance(processor, str) or processor is None: + processor = initialize_processor_from_path(model_args, model, teacher) pre_initialize_structure(model=model) @@ -330,7 +329,7 @@ def main( model_args=model_args, data_args=data_args, training_args=training_args ) add_labels = training_args.do_train or training_args.run_stages - stage_runner.populate_datasets(tokenizer=tokenizer, add_labels=add_labels) + stage_runner.populate_datasets(processor=processor, add_labels=add_labels) train_dataset = stage_runner.get_dataset_split("train") eval_dataset = stage_runner.get_dataset_split("validation") calib_dataset = stage_runner.get_dataset_split("calibration") @@ -346,13 +345,13 @@ def main( data_args=data_args, train_dataset=train_dataset or calib_dataset, eval_dataset=eval_dataset, - tokenizer=tokenizer, + processing_class=processor, data_collator=data_collator, ) # wrap model.save_pretrained if is_fsdp_model(model): - modify_fsdp_model_save_pretrained(trainer, tokenizer) + modify_fsdp_model_save_pretrained(trainer, processor) else: modify_save_pretrained(model) @@ -396,8 +395,8 @@ def main( model.save_pretrained( training_args.output_dir, save_compressed=training_args.save_compressed ) - if tokenizer is not None: - tokenizer.save_pretrained(training_args.output_dir) + if processor is not None: + processor.save_pretrained(training_args.output_dir) # Clean up the CompressionSession before exit if requested if training_args.clear_sparse_session: From fa317fdb8be0557fe26eeb80be4763e415476d2d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 2 Dec 2024 21:34:59 +0000 Subject: [PATCH 112/285] wip: use rename to processor, going through tests Signed-off-by: Kyle Sayers --- .../pytorch/model_load/helpers.py | 9 +- .../transformers/finetune/data/base.py | 158 +++++++++++------- .../transformers/finetune/data/c4.py | 11 +- .../finetune/data/cnn_dailymail.py | 24 +-- .../finetune/data/evolcodealpaca.py | 26 ++- .../transformers/finetune/data/gsm8k.py | 16 +- .../finetune/data/open_platypus.py | 16 +- .../transformers/finetune/data/ptb.py | 16 +- .../finetune/data/ultrachat_200k.py | 26 ++- .../transformers/finetune/data/wikitext.py | 23 ++- .../transformers/finetune/runner.py | 13 +- .../compressed_tensors_utils.py | 4 +- .../utils/preprocessing_functions.py | 7 +- src/llmcompressor/utils/__init__.py | 1 + src/llmcompressor/utils/fsdp/helpers.py | 6 +- src/llmcompressor/utils/typing.py | 6 + .../finetune/data/test_dataset_loading.py | 56 ++++--- .../finetune/data/test_registry.py | 12 +- 18 files changed, 256 insertions(+), 174 deletions(-) create mode 100644 src/llmcompressor/utils/typing.py diff --git a/src/llmcompressor/pytorch/model_load/helpers.py b/src/llmcompressor/pytorch/model_load/helpers.py index 11a924f1d..44b7f1bf7 100644 --- a/src/llmcompressor/pytorch/model_load/helpers.py +++ b/src/llmcompressor/pytorch/model_load/helpers.py @@ -9,6 +9,7 @@ from llmcompressor.core import active_session, create_session, pre_initialize_structure from llmcompressor.pytorch.utils import ModuleSparsificationInfo +from llmcompressor.utils import Processor COMPLETED_STAGES_FILENAME = "completed_stages.json" @@ -94,7 +95,7 @@ def initialize_recipe(model: Module, recipe_path: str): def save_model_and_recipe( model: Module, save_path: str, - tokenizer: Optional[Any] = None, + processor: Optional[Processor] = None, save_safetensors: bool = False, save_compressed: bool = False, ): @@ -102,7 +103,7 @@ def save_model_and_recipe( Save a model, tokenizer and the currently loaded recipe to file :param model: pytorch model to save :param save_path: path to save output to - :param tokenizer: model tokenizer to save + :param processor: model processor or tokenizer to save :param save_safetensors: whether to save as safetensors or pickle (bin) :param save_compressed: whether to compress sparse weights on disk """ @@ -111,8 +112,8 @@ def save_model_and_recipe( save_path, save_compressed=save_compressed, safe_serialization=save_safetensors ) - if tokenizer is not None: - tokenizer.save_pretrained(save_path) + if processor is not None: + processor.save_pretrained(save_path) logger.info("Saving output to {}".format(os.path.abspath(save_path))) diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index fd692face..aaab3189e 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -1,10 +1,9 @@ from functools import cached_property -from typing import Any, Callable, Union +from typing import Any, Callable, Dict, List, Optional, Union from compressed_tensors.registry import RegistryMixin from datasets import Dataset, DatasetDict, IterableDataset from loguru import logger -from transformers import AutoProcessor from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments from llmcompressor.transformers.finetune.data.data_helpers import ( @@ -15,7 +14,7 @@ from llmcompressor.transformers.utils.preprocessing_functions import ( PreprocessingFunctionRegistry, ) -from llmcompressor.utils import import_from_path +from llmcompressor.utils import Processor, import_from_path DatasetType = Union[Dataset, DatasetDict, IterableDataset] @@ -24,10 +23,9 @@ class TextGenerationDataset(RegistryMixin): """ Base class for text datasets, handles tokenization and dataset splits - :param text_column: name of column corresponding to text in the dataset :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` - :param tokenizer: tokenizer to use on dataset + :param processor: processor or tokenizer to use on dataset """ # used to mask out the prompt so prompt tokens do not contribute to training loss @@ -44,7 +42,7 @@ def __init__( self, data_args: DataTrainingArguments, split: str, - processor: "AutoProcessor", + processor: Processor, ): self.data_args = data_args self.split = split @@ -53,31 +51,47 @@ def __init__( # get tokenizer self.tokenizer = getattr(self.processor, "tokenizer", self.processor) - # fill in pad token - if self.tokenizer.pad_token: - self.tokenizer.pad_token = self.tokenizer.eos_token - - # configure sequence length - max_seq_length = data_args.max_seq_length - if data_args.max_seq_length > self.tokenizer.model_max_length: - logger.warning( - f"The max_seq_length passed ({max_seq_length}) is larger than " - f"the maximum length for model ({self.tokenizer.model_max_length}). " - f"Using max_seq_length={self.tokenizer.model_max_length}." + if self.tokenizer is not None: + # fill in pad token + if not self.tokenizer.pad_token: + self.tokenizer.pad_token = self.tokenizer.eos_token + + # configure sequence length + max_seq_length = data_args.max_seq_length + if data_args.max_seq_length > self.tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({max_seq_length}) is larger than " + f"maximum length for model ({self.tokenizer.model_max_length}). " + f"Using max_seq_length={self.tokenizer.model_max_length}." + ) + self.max_seq_length = min( + data_args.max_seq_length, self.tokenizer.model_max_length ) - self.max_seq_length = min( - data_args.max_seq_length, self.tokenizer.model_max_length - ) - def __call__(self, add_labels: bool) -> DatasetType: + # configure padding + self.padding = ( + False + if self.data_args.concatenate_data + else "max_length" + if self.data_args.pad_to_max_length + else False + ) + + else: + self.max_seq_length = None + self.padding = False + + def __call__(self, add_labels: bool = True) -> DatasetType: dataset = self.data_args.dataset - # load dataset + # 1. Load if isinstance(dataset, str): + # load dataset from huggingface or disk dataset = self.load_dataset() - # preprocess + # 2. Preprocess if self.preprocess is not None: + # apply template or preprocessing function dataset = self.map( dataset, self.preprocess, @@ -87,8 +101,13 @@ def __call__(self, add_labels: bool) -> DatasetType: desc="Preprocessing", ) - # tokenize - if "input_ids" not in dataset.column_names: + # rename and remove columns match processor kwargs + dataset = self.rename_columns(dataset) + + # 3. Process + if self.processor is not None and "input_ids" not in dataset.column_names: + + # tokenize/ process dataset = self.map( dataset, self.tokenize, @@ -99,13 +118,12 @@ def __call__(self, add_labels: bool) -> DatasetType: desc="Tokenizing", ) - # postprocess - + # 4. Postprocess if self.data_args.concatenate_data: - # group text + # postprocess: group text dataset = self.map( dataset, - function=self.group_text, + self.group_text, batched=True, num_proc=self.data_args.preprocessing_num_workers, load_from_cache_file=not self.data_args.overwrite_cache, @@ -113,17 +131,17 @@ def __call__(self, add_labels: bool) -> DatasetType: ) if add_labels: - # add labels + # postprocess: add labels dataset = self.map( dataset, - function=self.add_labels, + self.add_labels, batched=False, # not compatible with batching, need row lengths num_proc=self.data_args.preprocessing_num_workers, load_from_cache_file=not self.data_args.overwrite_cache, desc="Adding labels", ) - if self.PROMPT_KEY in dataset: + elif self.PROMPT_KEY in dataset.column_names: del dataset[self.PROMPT_KEY] return dataset @@ -159,6 +177,11 @@ def load_dataset(self): @cached_property def preprocess(self) -> Union[Callable[[Any], Any], None]: + """ + + The function must return keys which correspond to tokenizer kwargs, optionally + including PROMPT_KEY + """ preprocessing_func = self.data_args.preprocessing_func if callable(preprocessing_func): @@ -180,37 +203,39 @@ def preprocess(self) -> Union[Callable[[Any], Any], None]: def dataset_template(self) -> Union[Callable[[Any], Any], None]: return None - def tokenize(self, dataset: DatasetType) -> DatasetType: - # manually swap text argument if specified - if self.data_args.text_column not in dataset.column_names: - dataset["text"] = dataset[self.data_args.text_column] + def rename_columns(self, dataset: DatasetType) -> DatasetType: + # rename columns to match processor/tokenizer kwargs + if ( + self.data_args.text_column != "text" + and self.data_args.text_column in dataset.column_names + ): + dataset = dataset.rename_column(self.data_args.text_column, "text") + + return dataset + + def tokenize(self, data: Dict[str, Any]) -> Dict[str, Any]: + # separate prompt + prompt = data.pop(self.PROMPT_KEY, None) # tokenize - dataset = self.processor( - **dataset, - padding=( - False - if self.data_args.concatenate_data - else "max_length" - if self.data_args.pad_to_max_length - else False - ), + data = self.processor( + **data, + padding=self.padding, max_length=self.max_seq_length, truncation=True, ) - # store unpadded prompt so we can mask out correct number of elements - # in the labels - if self.PROMPT_KEY in dataset: - dataset[self.PROMPT_KEY] = self.processor( - dataset[self.PROMPT_KEY], + # store unpadded prompt so we can mask out correct number of elements in labels + if prompt is not None: + data[self.PROMPT_KEY] = self.processor( + prompt, max_length=self.max_seq_length, truncation=True, )["input_ids"] - return dataset + return data - def group_text(self, data): + def group_text(self, data: Dict[str, Any]) -> Dict[str, Any]: concatenated_data = {k: sum(data[k], []) for k in data.keys()} total_length = len(concatenated_data[list(data.keys())[0]]) total_length = (total_length // self.max_seq_length) * self.max_seq_length @@ -241,18 +266,17 @@ def add_labels(self, data): def map( self, dataset: Union[Dataset, IterableDataset], - map_fn: Union[Callable[[Any], Any], None], + function: Union[Callable[[Any], Any], None], + remove_columns: Optional[Union[str, List[str], Dict[str, List[str]]]] = None, **kwargs, ) -> Union[Dataset, IterableDataset]: """ - Wrapper function around Dataset.map and IterableDataset.map, clears invalid - parameters in the case where streaming is enabled + Wrapper function around Dataset.map and IterableDataset.map - :param dataset: dataset to apply mapping to - :param kwargs: args to pass on to map function - :return: mapped dataset + 1. Clears invalid parameters in the case where streaming is enabled + 2. Skips removing columns which were already removed after mapping """ - if map_fn is None: + if function is None: return dataset if isinstance(dataset, IterableDataset): @@ -261,9 +285,23 @@ def map( kwargs.pop("load_from_cache_file", None) kwargs.pop("desc", None) - dataset = dataset.map(map_fn, **kwargs) + dataset = dataset.map(function, **kwargs) if isinstance(dataset, IterableDataset): dataset = dataset._resolve_features() + if remove_columns is not None: + if isinstance(remove_columns, str): + remove_columns = [remove_columns] + + dataset_column_names = dataset.column_names + if isinstance(dataset_column_names, dict): + dataset_column_names = sum(dataset_column_names.values(), []) + if isinstance(remove_columns, dict): + remove_columns = sum(remove_columns.values(), []) + + dataset = dataset.remove_columns( + list(set(dataset_column_names) & set(remove_columns)) + ) + return dataset diff --git a/src/llmcompressor/transformers/finetune/data/c4.py b/src/llmcompressor/transformers/finetune/data/c4.py index e7f82814e..fa89b9883 100644 --- a/src/llmcompressor/transformers/finetune/data/c4.py +++ b/src/llmcompressor/transformers/finetune/data/c4.py @@ -1,6 +1,8 @@ from copy import deepcopy from llmcompressor.transformers.finetune.data import TextGenerationDataset +from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments +from llmcompressor.utils import Processor @TextGenerationDataset.register(name="c4") @@ -10,10 +12,15 @@ class C4Dataset(TextGenerationDataset): :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` - :param processor: processor to use on dataset + :param processor: processor or tokenizer to use on dataset """ - def __init__(self, data_args, split, processor): + def __init__( + self, + data_args: DataTrainingArguments, + split: str, + processor: Processor, + ): data_args = deepcopy(data_args) data_args.dataset = "allenai/c4" data_args.text_column = "text" diff --git a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py index 02c266656..739126054 100644 --- a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py +++ b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py @@ -1,19 +1,8 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. from copy import deepcopy from llmcompressor.transformers.finetune.data import TextGenerationDataset +from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments +from llmcompressor.utils import Processor @TextGenerationDataset.register(name="cnn_dailymail") @@ -23,12 +12,17 @@ class CNNDailyMailDataset(TextGenerationDataset): :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` - :param processor: processor to use on dataset + :param processor: processor or tokenizer to use on dataset """ SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n" - def __init__(self, data_args, split, processor): + def __init__( + self, + data_args: DataTrainingArguments, + split: str, + processor: Processor, + ): data_args = deepcopy(data_args) data_args.dataset = "cnn_dailymail" data_args.dataset_config_name = "3.0.0" diff --git a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py index 4e96409e8..60657c6a3 100644 --- a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py +++ b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py @@ -1,19 +1,8 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. from copy import deepcopy from llmcompressor.transformers.finetune.data import TextGenerationDataset +from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments +from llmcompressor.utils import Processor @TextGenerationDataset.register(name="evolcodealpaca") @@ -23,7 +12,7 @@ class EvolCodeAlpacaDataset(TextGenerationDataset): :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` - :param tokenizer: tokenizer to use on dataset + :param processor: processor or tokenizer to use on dataset """ EVOL_ALPACA_TEMPLATE = ( @@ -33,12 +22,17 @@ class EvolCodeAlpacaDataset(TextGenerationDataset): "\n\n### Response:\n" ) - def __init__(self, data_args, split, tokenizer): + def __init__( + self, + data_args: DataTrainingArguments, + split: str, + processor: Processor, + ): data_args = deepcopy(data_args) data_args.dataset = "theblackcat102/evol-codealpaca-v1" data_args.text_column = "text" - super().__init__(data_args, split=split, tokenizer=tokenizer) + super().__init__(data_args, split=split, processor=processor) def dataset_template(self, sample): prompt = self.EVOL_ALPACA_TEMPLATE.format(instruction=sample["instruction"]) diff --git a/src/llmcompressor/transformers/finetune/data/gsm8k.py b/src/llmcompressor/transformers/finetune/data/gsm8k.py index a4e8f41cd..d1b27b9b3 100644 --- a/src/llmcompressor/transformers/finetune/data/gsm8k.py +++ b/src/llmcompressor/transformers/finetune/data/gsm8k.py @@ -1,6 +1,8 @@ from copy import deepcopy from llmcompressor.transformers.finetune.data import TextGenerationDataset +from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments +from llmcompressor.utils import Processor @TextGenerationDataset.register(name="gsm8k") @@ -10,18 +12,22 @@ class GSM8KDataset(TextGenerationDataset): :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` - :param tokenizer: tokenizer to use on dataset + :param processor: processor or tokenizer to use on dataset """ GSM_TEMPLATE = "Question: {question}\nAnswer:" - def __init__(self, data_args, split, tokenizer): + def __init__( + self, + data_args: DataTrainingArguments, + split: str, + processor: Processor, + ): data_args = deepcopy(data_args) data_args.dataset = "gsm8k" + data_args.text_column = "text" - super().__init__( - text_column="text", data_args=data_args, split=split, tokenizer=tokenizer - ) + super().__init__(data_args=data_args, split=split, processor=processor) def dataset_template(self, sample): prompt = self.GSM_TEMPLATE.format(question=sample["question"]) diff --git a/src/llmcompressor/transformers/finetune/data/open_platypus.py b/src/llmcompressor/transformers/finetune/data/open_platypus.py index 4b85331db..88a7b02fc 100644 --- a/src/llmcompressor/transformers/finetune/data/open_platypus.py +++ b/src/llmcompressor/transformers/finetune/data/open_platypus.py @@ -14,6 +14,8 @@ from copy import deepcopy from llmcompressor.transformers.finetune.data import TextGenerationDataset +from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments +from llmcompressor.utils import Processor @TextGenerationDataset.register(name="open_platypus") @@ -23,7 +25,7 @@ class OpenPlatypusDataset(TextGenerationDataset): :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` - :param tokenizer: tokenizer to use on dataset + :param processor: processor or tokenizer to use on dataset """ ALPACA_TEMPLATE = { @@ -36,12 +38,16 @@ class OpenPlatypusDataset(TextGenerationDataset): "instruction}\n\n### Response:\n", } - def __init__(self, data_args, split, tokenizer): + def __init__( + self, + data_args: DataTrainingArguments, + split: str, + processor: Processor, + ): data_args = deepcopy(data_args) data_args.dataset = "garage-bAInd/Open-Platypus" - super().__init__( - text_column="text", data_args=data_args, split=split, tokenizer=tokenizer - ) + data_args.text_column = "text" + super().__init__(data_args=data_args, split=split, processor=processor) def dataset_template(self, sample): if "input" in sample and sample["input"] != "": diff --git a/src/llmcompressor/transformers/finetune/data/ptb.py b/src/llmcompressor/transformers/finetune/data/ptb.py index 6f502edaf..78a6d865b 100644 --- a/src/llmcompressor/transformers/finetune/data/ptb.py +++ b/src/llmcompressor/transformers/finetune/data/ptb.py @@ -1,6 +1,8 @@ from copy import deepcopy from llmcompressor.transformers.finetune.data import TextGenerationDataset +from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments +from llmcompressor.utils import Processor @TextGenerationDataset.register(name="ptb") @@ -10,15 +12,21 @@ class PtbDataset(TextGenerationDataset): :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` - :param tokenizer: tokenizer to use on dataset + :param processor: processor or tokenizer to use on dataset """ - def __init__(self, data_args, split, tokenizer): + def __init__( + self, + data_args: DataTrainingArguments, + split: str, + processor: Processor, + ): data_args = deepcopy(data_args) data_args.dataset = "ptb_text_only" + data_args.text_column = "sentence" + super().__init__( - text_column="sentence", data_args=data_args, split=split, - tokenizer=tokenizer, + processor=processor, ) diff --git a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py index 0b71d6893..1b7fc3f0f 100644 --- a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py +++ b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py @@ -1,19 +1,8 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. from copy import deepcopy from llmcompressor.transformers.finetune.data import TextGenerationDataset +from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments +from llmcompressor.utils import Processor @TextGenerationDataset.register(name="ultrachat_200k") @@ -23,7 +12,7 @@ class UltraChatDataset(TextGenerationDataset): :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` - :param tokenizer: tokenizer to use on dataset + :param processor: processor or tokenizer to use on dataset """ DEFAULT_CHAT_TEMPLATE = ( @@ -39,7 +28,12 @@ class UltraChatDataset(TextGenerationDataset): "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" ) - def __init__(self, data_args, split, tokenizer): + def __init__( + self, + data_args: DataTrainingArguments, + split: str, + processor: Processor, + ): data_args = deepcopy(data_args) data_args.dataset = "HuggingFaceH4/ultrachat_200k" data_args.text_column = "messages" @@ -50,7 +44,7 @@ def __init__(self, data_args, split, tokenizer): super().__init__( data_args=data_args, split=split, - tokenizer=tokenizer, + processor=processor, ) if getattr(self.tokenizer, "chat_template", None) is None: diff --git a/src/llmcompressor/transformers/finetune/data/wikitext.py b/src/llmcompressor/transformers/finetune/data/wikitext.py index 01a9c1ed0..15a6b8fcf 100644 --- a/src/llmcompressor/transformers/finetune/data/wikitext.py +++ b/src/llmcompressor/transformers/finetune/data/wikitext.py @@ -1,4 +1,8 @@ +from copy import deepcopy + from llmcompressor.transformers.finetune.data import TextGenerationDataset +from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments +from llmcompressor.utils import Processor @TextGenerationDataset.register(name="wikitext") @@ -8,8 +12,21 @@ class WikiTextDataset(TextGenerationDataset): :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` - :param tokenizer: tokenizer to use on dataset + :param processor: processor or tokenizer to use on dataset """ - def __init__(self, data_args, split, tokenizer): - super().__init__(data_args=data_args, split=split, tokenizer=tokenizer) + def __init__( + self, + data_args: DataTrainingArguments, + split: str, + processor: Processor, + ): + data_args = deepcopy(data_args) + data_args.dataset = "Salesforce/wikitext" + data_args.text_column = "text" + + super().__init__( + data_args=data_args, + split=split, + processor=processor, + ) diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py index 530c5198e..cae7ae18f 100644 --- a/src/llmcompressor/transformers/finetune/runner.py +++ b/src/llmcompressor/transformers/finetune/runner.py @@ -1,12 +1,11 @@ import math import os import re -from typing import List, Optional, Union +from typing import List, Optional import torch from loguru import logger from torch.utils.data import Dataset -from transformers import AutoProcessor, PreTrainedTokenizerBase from llmcompressor.core import active_session from llmcompressor.pytorch.model_load.helpers import ( @@ -24,6 +23,7 @@ ) from llmcompressor.transformers.finetune.model_args import ModelArguments from llmcompressor.transformers.finetune.training_args import TrainingArguments +from llmcompressor.utils import Processor from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_model_and_recipe @@ -61,17 +61,18 @@ def __init__( def populate_datasets( self, - processor: Union["AutoProcessor", PreTrainedTokenizerBase], + processor: Processor, add_labels: bool = True, ): """ Loads datasets for each flow based on data_args, stores a Dataset for each enabled flow in self.datasets - :param tokenizer: tokenizer to use for dataset tokenization + :param processor: processor or tokenizer to use for dataset tokenization + :param add_labels: if True, add labels column to dataset splits """ if self._data_args.dataset is None: - self.tokenizer = self._model_args.processor + self.processor = self._model_args.processor logger.info( "Running oneshot without calibration data. This is expected for " "weight-only and dynamic quantization" @@ -258,7 +259,7 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None): save_model_and_recipe( model=self.trainer.model, save_path=self._output_dir, - tokenizer=self.tokenizer, + processor=self.processor, save_safetensors=self._training_args.save_safetensors, save_compressed=self._training_args.save_compressed, ) diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index 6de89dd8b..4140377bd 100644 --- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -32,7 +32,7 @@ __all__ = ["modify_save_pretrained", "modify_fsdp_model_save_pretrained"] -def modify_fsdp_model_save_pretrained(trainer, tokenizer): +def modify_fsdp_model_save_pretrained(trainer, processor): """ Overrides a PreTrainedModel's save_pretrained() method with a wrapped version that supports compression for fsdp model @@ -77,7 +77,7 @@ def save_pretrained_wrapper( model=trainer.model, accelerator=trainer.accelerator, output_dir=save_directory, - tokenizer=tokenizer, + processor=processor, ) # only allow the main process move the state # dicts to cpu diff --git a/src/llmcompressor/transformers/utils/preprocessing_functions.py b/src/llmcompressor/transformers/utils/preprocessing_functions.py index cadec88f0..6bf6ade42 100644 --- a/src/llmcompressor/transformers/utils/preprocessing_functions.py +++ b/src/llmcompressor/transformers/utils/preprocessing_functions.py @@ -1,14 +1,17 @@ -from typing import Dict +from typing import TYPE_CHECKING, Dict from compressed_tensors.registry import RegistryMixin +if TYPE_CHECKING: + from llmcompressor.transformers.finetune.data.base import TextGenerationDataset + class PreprocessingFunctionRegistry(RegistryMixin): pass @PreprocessingFunctionRegistry.register() -def custom_evolved_codealpaca_dataset(data: Dict): +def custom_evolved_codealpaca_dataset(self: "TextGenerationDataset", data: Dict): PROMPT_DICT = """[Instruction]:\n{instruction}\n\n[Response]:""" data["prompt"] = PROMPT_DICT.format_map(data) data["text"] = data["prompt"] + data["output"] diff --git a/src/llmcompressor/utils/__init__.py b/src/llmcompressor/utils/__init__.py index 98d5e1c65..d36055acb 100644 --- a/src/llmcompressor/utils/__init__.py +++ b/src/llmcompressor/utils/__init__.py @@ -5,3 +5,4 @@ # flake8: noqa from .helpers import * +from .typing import * diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py index 8cc0f5405..32a4be3f7 100644 --- a/src/llmcompressor/utils/fsdp/helpers.py +++ b/src/llmcompressor/utils/fsdp/helpers.py @@ -71,7 +71,7 @@ def set_wrapped_model(state: State, wrapped_model: Module): state.model = wrapped_model -def unwrap_and_export_model(model, accelerator, output_dir, tokenizer): +def unwrap_and_export_model(model, accelerator, output_dir, processor): """ Recursively unwraps an FSDP model, then saves the unwrapped model and the currently active recipe to disk @@ -79,7 +79,7 @@ def unwrap_and_export_model(model, accelerator, output_dir, tokenizer): :param model: model to unwrap :param accelerator: Accelerator instance used to perform unwrapping :param output_dir: where to save output model - :param tokenizer: tokenizer used by the model + :param processor: processor used by the model """ full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) with FullyShardedDataParallel.state_dict_type( @@ -95,7 +95,7 @@ def unwrap_and_export_model(model, accelerator, output_dir, tokenizer): save_model_and_recipe( model=unwrapped_model, save_path=output_dir, - tokenizer=tokenizer, + processor=processor, ) diff --git a/src/llmcompressor/utils/typing.py b/src/llmcompressor/utils/typing.py new file mode 100644 index 000000000..fa311ecb7 --- /dev/null +++ b/src/llmcompressor/utils/typing.py @@ -0,0 +1,6 @@ +from typing import Any, Union + +from transformers import PreTrainedTokenizer + +# Tokenizer or Processor. Processors do not inherit from a unified base class +Processor = Union[PreTrainedTokenizer, Any] diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py index 3415858af..6bd0312b1 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py @@ -29,13 +29,13 @@ def test_concatenation_tokenization(self): self.data_args.dataset, data_args=self.data_args, split="train[:5%]", - tokenizer=self.tiny_llama_tokenizer, + processor=self.tiny_llama_tokenizer, ) - raw_dataset = wiki_manager.get_raw_dataset() + raw_dataset = wiki_manager.load_dataset() self.assertGreater(len(raw_dataset), 0) self.assertEqual(raw_dataset.split, "train[:5%]") self.assertEqual(raw_dataset.info.config_name, "wikitext-2-raw-v1") - tokenized_dataset = wiki_manager.tokenize_and_process(raw_dataset) + tokenized_dataset = wiki_manager() self.assertIn("input_ids", tokenized_dataset.features) self.assertIn("labels", tokenized_dataset.features) for i in range(len(tokenized_dataset)): @@ -61,15 +61,22 @@ def test_no_padding_tokenization(self): self.data_args.dataset, data_args=self.data_args, split="train[5%:10%]", - tokenizer=self.tiny_llama_tokenizer, + processor=self.tiny_llama_tokenizer, ) - raw_dataset = op_manager.get_raw_dataset() - self.assertGreater(len(raw_dataset), 0) - ex_item = raw_dataset[0]["text"] + dataset = op_manager.load_dataset() # load + dataset = op_manager.map( # preprocess + dataset, + op_manager.preprocess, + batched=False, + num_proc=op_manager.data_args.preprocessing_num_workers, + ) + dataset = op_manager.rename_columns(dataset) # rename + self.assertGreater(len(dataset), 0) + ex_item = dataset[0]["text"] self.assertIn("Below is an instruction that describes a task", ex_item) - self.assertEqual(raw_dataset.split, "train[5%:10%]") - tokenized_dataset = op_manager.tokenize_and_process(raw_dataset) + self.assertEqual(dataset.split, "train[5%:10%]") + tokenized_dataset = op_manager() self.assertIn("input_ids", tokenized_dataset.features) self.assertIn("labels", tokenized_dataset.features) print(tokenized_dataset[0]["input_ids"]) @@ -96,7 +103,7 @@ def test_max_seq_len_clipped(self): self.data_args.dataset, data_args=self.data_args, split="train[80%:]", - tokenizer=self.tiny_llama_tokenizer, + processor=self.tiny_llama_tokenizer, ) self.assertEqual( @@ -125,17 +132,17 @@ def test_dataset_kwargs_and_percentages(self): self.data_args.dataset, data_args=self.data_args, split="train[5%:10%]", - tokenizer=self.tiny_llama_tokenizer, + processor=self.tiny_llama_tokenizer, ) - raw_dataset_a = c4_manager_a.get_raw_dataset() + raw_dataset_a = c4_manager_a.load_dataset() c4_manager_b = TextGenerationDataset.load_from_registry( self.data_args.dataset, data_args=self.data_args, split="train[5%:15%]", - tokenizer=self.tiny_llama_tokenizer, + processor=self.tiny_llama_tokenizer, ) - raw_dataset_b = c4_manager_b.get_raw_dataset() + raw_dataset_b = c4_manager_b.load_dataset() self.assertEqual(len(raw_dataset_b), 2 * len(raw_dataset_a)) @@ -164,14 +171,14 @@ def test_datasets(self, dataset_key, dataset_config, split, do_concat): data_args.dataset, data_args=data_args, split=split, - tokenizer=self.tiny_llama_tokenizer, + processor=self.tiny_llama_tokenizer, ) - raw_dataset = manager.get_raw_dataset() + raw_dataset = manager.load_dataset() self.assertGreater(len(raw_dataset), 0) self.assertEqual(raw_dataset.split, split) self.assertEqual(raw_dataset.info.config_name, dataset_config) - tokenized_dataset = manager.tokenize_and_process(raw_dataset) + tokenized_dataset = manager() self.assertIn("input_ids", tokenized_dataset.features) self.assertIn("labels", tokenized_dataset.features) for i in range(len(tokenized_dataset)): @@ -204,13 +211,13 @@ def test_evol(self): self.data_args.dataset, data_args=self.data_args, split="train[:2%]", - tokenizer=self.tiny_llama_tokenizer, + processor=self.tiny_llama_tokenizer, ) - raw_dataset = evol_manager.get_raw_dataset() + raw_dataset = evol_manager.load_dataset() self.assertGreater(len(raw_dataset), 0) self.assertEqual(raw_dataset.split, "train[:2%]") - tokenized_dataset = evol_manager.tokenize_and_process(raw_dataset) + tokenized_dataset = evol_manager() self.assertIn("input_ids", tokenized_dataset.features) self.assertIn("labels", tokenized_dataset.features) for i in range(len(tokenized_dataset)): @@ -238,11 +245,10 @@ def test_stream_loading(self): self.data_args.dataset, data_args=self.data_args, split="train", - tokenizer=self.tiny_llama_tokenizer, + processor=self.tiny_llama_tokenizer, ) - raw_dataset = manager.get_raw_dataset() - processed = manager.tokenize_and_process(raw_dataset) + processed = manager() self.assertIsInstance(processed, IterableDataset) with pytest.raises(TypeError): # in streaming mode we don't know the length of the dataset @@ -276,7 +282,7 @@ def test_split_loading(self, split_def): stage_runner = StageRunner( model_args=model_args, data_args=data_args, training_args=training_args ) - stage_runner.populate_datasets(tokenizer=self.tiny_llama_tokenizer) + stage_runner.populate_datasets(processor=self.tiny_llama_tokenizer) train_dataset = stage_runner.get_dataset_split("train") assert train_dataset is not None @@ -320,7 +326,7 @@ def preprocess(sample): ), training_args=TrainingArguments(do_oneshot=True), ) - stage_runner.populate_datasets(tokenizer=None) + stage_runner.populate_datasets(processor=None) calib_dataset = stage_runner.get_dataset_split("calibration") self.assertEqual(len(calib_dataset), self.num_calib_samples) data_cols = calib_dataset.column_names diff --git a/tests/llmcompressor/transformers/finetune/data/test_registry.py b/tests/llmcompressor/transformers/finetune/data/test_registry.py index e4c804c07..9aee4c20f 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_registry.py +++ b/tests/llmcompressor/transformers/finetune/data/test_registry.py @@ -16,11 +16,11 @@ def test_c4_initializes(tiny_llama_tokenizer): data_args.dataset, data_args=data_args, split=None, - tokenizer=tiny_llama_tokenizer, + processor=tiny_llama_tokenizer, ) assert isinstance(c4_manager, TextGenerationDataset) assert isinstance(c4_manager, C4Dataset) - assert c4_manager.text_column == "text" + assert c4_manager.data_args.text_column == "text" assert not c4_manager.padding assert c4_manager.max_seq_length == data_args.max_seq_length @@ -34,11 +34,11 @@ def test_wikitext_initializes(tiny_llama_tokenizer): data_args.dataset, data_args=data_args, split=None, - tokenizer=tiny_llama_tokenizer, + processor=tiny_llama_tokenizer, ) assert isinstance(wiki_manager, TextGenerationDataset) assert isinstance(wiki_manager, WikiTextDataset) - assert wiki_manager.text_column == "text" + assert wiki_manager.data_args.text_column == "text" assert wiki_manager.padding == "max_length" assert wiki_manager.max_seq_length == data_args.max_seq_length @@ -50,10 +50,10 @@ def test_open_platypus_initializes(tiny_llama_tokenizer): data_args.dataset, data_args=data_args, split=None, - tokenizer=tiny_llama_tokenizer, + processor=tiny_llama_tokenizer, ) assert isinstance(op_manager, TextGenerationDataset) assert isinstance(op_manager, OpenPlatypusDataset) - assert op_manager.text_column == "text" + assert op_manager.data_args.text_column == "text" assert not op_manager.padding assert op_manager.max_seq_length == data_args.max_seq_length From f3f5875d123245f69a0a2f8c8a25eec4aa26952d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 2 Dec 2024 21:48:34 +0000 Subject: [PATCH 113/285] remove labels from calibration dataset rather than assuming that all tokenized datasets should not be given labels Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/finetune/data/base.py | 1 - .../transformers/finetune/data/data_helpers.py | 7 +++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index aaab3189e..cf2fcee85 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -106,7 +106,6 @@ def __call__(self, add_labels: bool = True) -> DatasetType: # 3. Process if self.processor is not None and "input_ids" not in dataset.column_names: - # tokenize/ process dataset = self.map( dataset, diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index 23c70e561..1b0b589f9 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -137,6 +137,13 @@ def make_dataset_splits( raise ValueError("--do_oneshot requires a calibration dataset") calib_split = tokenized_datasets["train"] + # remove labels from calibration dataset + column_names = calib_split.column_names + if isinstance(column_names, dict): + column_names = sum(column_names.values(), []) + if "labels" in column_names: + calib_split = calib_split.remove_columns("labels") + split_datasets = { "train": train_split, "validation": eval_split, From 58c3afe20bd3af2654d9a36f1e8a3ec9f8ae6415 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 2 Dec 2024 21:52:26 +0000 Subject: [PATCH 114/285] cleanup Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/finetune/data/c4.py | 9 ++------- .../transformers/finetune/data/cnn_dailymail.py | 9 ++------- .../transformers/finetune/data/evolcodealpaca.py | 9 ++------- src/llmcompressor/transformers/finetune/data/gsm8k.py | 9 ++------- .../transformers/finetune/data/open_platypus.py | 9 ++------- src/llmcompressor/transformers/finetune/data/wikitext.py | 9 ++------- src/llmcompressor/transformers/finetune/runner.py | 8 ++------ 7 files changed, 14 insertions(+), 48 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/c4.py b/src/llmcompressor/transformers/finetune/data/c4.py index fa89b9883..ab890ff86 100644 --- a/src/llmcompressor/transformers/finetune/data/c4.py +++ b/src/llmcompressor/transformers/finetune/data/c4.py @@ -1,7 +1,7 @@ from copy import deepcopy +from llmcompressor.transformers.finetune.data import DataTrainingArguments as DataArgs from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments from llmcompressor.utils import Processor @@ -15,12 +15,7 @@ class C4Dataset(TextGenerationDataset): :param processor: processor or tokenizer to use on dataset """ - def __init__( - self, - data_args: DataTrainingArguments, - split: str, - processor: Processor, - ): + def __init__(self, data_args: DataArgs, split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "allenai/c4" data_args.text_column = "text" diff --git a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py index 739126054..39757cd76 100644 --- a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py +++ b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py @@ -1,7 +1,7 @@ from copy import deepcopy +from llmcompressor.transformers.finetune.data import DataTrainingArguments as DataArgs from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments from llmcompressor.utils import Processor @@ -17,12 +17,7 @@ class CNNDailyMailDataset(TextGenerationDataset): SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n" - def __init__( - self, - data_args: DataTrainingArguments, - split: str, - processor: Processor, - ): + def __init__(self, data_args: DataArgs, split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "cnn_dailymail" data_args.dataset_config_name = "3.0.0" diff --git a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py index 60657c6a3..ce25858e5 100644 --- a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py +++ b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py @@ -1,7 +1,7 @@ from copy import deepcopy +from llmcompressor.transformers.finetune.data import DataTrainingArguments as DataArgs from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments from llmcompressor.utils import Processor @@ -22,12 +22,7 @@ class EvolCodeAlpacaDataset(TextGenerationDataset): "\n\n### Response:\n" ) - def __init__( - self, - data_args: DataTrainingArguments, - split: str, - processor: Processor, - ): + def __init__(self, data_args: DataArgs, split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "theblackcat102/evol-codealpaca-v1" data_args.text_column = "text" diff --git a/src/llmcompressor/transformers/finetune/data/gsm8k.py b/src/llmcompressor/transformers/finetune/data/gsm8k.py index d1b27b9b3..60e4d8da4 100644 --- a/src/llmcompressor/transformers/finetune/data/gsm8k.py +++ b/src/llmcompressor/transformers/finetune/data/gsm8k.py @@ -1,7 +1,7 @@ from copy import deepcopy +from llmcompressor.transformers.finetune.data import DataTrainingArguments as DataArgs from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments from llmcompressor.utils import Processor @@ -17,12 +17,7 @@ class GSM8KDataset(TextGenerationDataset): GSM_TEMPLATE = "Question: {question}\nAnswer:" - def __init__( - self, - data_args: DataTrainingArguments, - split: str, - processor: Processor, - ): + def __init__(self, data_args: DataArgs, split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "gsm8k" data_args.text_column = "text" diff --git a/src/llmcompressor/transformers/finetune/data/open_platypus.py b/src/llmcompressor/transformers/finetune/data/open_platypus.py index 88a7b02fc..6eb78a592 100644 --- a/src/llmcompressor/transformers/finetune/data/open_platypus.py +++ b/src/llmcompressor/transformers/finetune/data/open_platypus.py @@ -13,8 +13,8 @@ # limitations under the License. from copy import deepcopy +from llmcompressor.transformers.finetune.data import DataTrainingArguments as DataArgs from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments from llmcompressor.utils import Processor @@ -38,12 +38,7 @@ class OpenPlatypusDataset(TextGenerationDataset): "instruction}\n\n### Response:\n", } - def __init__( - self, - data_args: DataTrainingArguments, - split: str, - processor: Processor, - ): + def __init__(self, data_args: DataArgs, split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "garage-bAInd/Open-Platypus" data_args.text_column = "text" diff --git a/src/llmcompressor/transformers/finetune/data/wikitext.py b/src/llmcompressor/transformers/finetune/data/wikitext.py index 15a6b8fcf..4640fb90c 100644 --- a/src/llmcompressor/transformers/finetune/data/wikitext.py +++ b/src/llmcompressor/transformers/finetune/data/wikitext.py @@ -1,7 +1,7 @@ from copy import deepcopy +from llmcompressor.transformers.finetune.data import DataTrainingArguments as DataArgs from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments from llmcompressor.utils import Processor @@ -15,12 +15,7 @@ class WikiTextDataset(TextGenerationDataset): :param processor: processor or tokenizer to use on dataset """ - def __init__( - self, - data_args: DataTrainingArguments, - split: str, - processor: Processor, - ): + def __init__(self, data_args: DataArgs, split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "Salesforce/wikitext" data_args.text_column = "text" diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py index cae7ae18f..6a89e8a7a 100644 --- a/src/llmcompressor/transformers/finetune/runner.py +++ b/src/llmcompressor/transformers/finetune/runner.py @@ -59,11 +59,7 @@ def __init__( self.parent_output_dir = self._training_args.output_dir self._output_dir = self._training_args.output_dir - def populate_datasets( - self, - processor: Processor, - add_labels: bool = True, - ): + def populate_datasets(self, processor: Processor, add_labels: bool = True): """ Loads datasets for each flow based on data_args, stores a Dataset for each enabled flow in self.datasets @@ -104,7 +100,7 @@ def _get_split_name(inp_str): ) for split_name, split_str in splits.items(): dataset_manager = TextGenerationDataset.load_from_registry( - name=registry_id, + registry_id, data_args=self._data_args, split=split_str, processor=processor, From 72aecfcadb4237dc0f8250725981d473d406e24a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 2 Dec 2024 22:14:08 +0000 Subject: [PATCH 115/285] cleanup, etc Signed-off-by: Kyle Sayers --- src/llmcompressor/pytorch/model_load/helpers.py | 3 ++- .../transformers/finetune/data/base.py | 10 +++------- .../transformers/finetune/model_args.py | 1 + .../transformers/finetune/session_mixin.py | 5 +++-- .../transformers/finetune/text_generation.py | 4 ++-- .../transformers/sparsification/sparse_model.py | 8 ++++---- src/llmcompressor/transformers/utils/helpers.py | 1 + src/llmcompressor/utils/typing.py | 13 ++++++++++--- 8 files changed, 26 insertions(+), 19 deletions(-) diff --git a/src/llmcompressor/pytorch/model_load/helpers.py b/src/llmcompressor/pytorch/model_load/helpers.py index 44b7f1bf7..65757ac8f 100644 --- a/src/llmcompressor/pytorch/model_load/helpers.py +++ b/src/llmcompressor/pytorch/model_load/helpers.py @@ -100,7 +100,8 @@ def save_model_and_recipe( save_compressed: bool = False, ): """ - Save a model, tokenizer and the currently loaded recipe to file + Save a model, processor and the currently loaded recipe to file + :param model: pytorch model to save :param save_path: path to save output to :param processor: model processor or tokenizer to save diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index cf2fcee85..b1d822eeb 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -84,14 +84,12 @@ def __init__( def __call__(self, add_labels: bool = True) -> DatasetType: dataset = self.data_args.dataset - # 1. Load if isinstance(dataset, str): - # load dataset from huggingface or disk + # load dataset: load from huggingface or disk dataset = self.load_dataset() - # 2. Preprocess if self.preprocess is not None: - # apply template or preprocessing function + # preprocess: apply template or preprocessing function dataset = self.map( dataset, self.preprocess, @@ -104,7 +102,6 @@ def __call__(self, add_labels: bool = True) -> DatasetType: # rename and remove columns match processor kwargs dataset = self.rename_columns(dataset) - # 3. Process if self.processor is not None and "input_ids" not in dataset.column_names: # tokenize/ process dataset = self.map( @@ -117,7 +114,6 @@ def __call__(self, add_labels: bool = True) -> DatasetType: desc="Tokenizing", ) - # 4. Postprocess if self.data_args.concatenate_data: # postprocess: group text dataset = self.map( @@ -141,7 +137,7 @@ def __call__(self, add_labels: bool = True) -> DatasetType: ) elif self.PROMPT_KEY in dataset.column_names: - del dataset[self.PROMPT_KEY] + dataset.remove_columns(self.PROMPT_KEY) return dataset diff --git a/src/llmcompressor/transformers/finetune/model_args.py b/src/llmcompressor/transformers/finetune/model_args.py index 606d440cb..8fbcad70e 100644 --- a/src/llmcompressor/transformers/finetune/model_args.py +++ b/src/llmcompressor/transformers/finetune/model_args.py @@ -28,6 +28,7 @@ class ModelArguments: "help": "Pretrained config name or path if not the same as model_name" }, ) + # TODO: depreciate tokenizer: Optional[str] = field( default=None, metadata={ diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py index e3a9c4d84..04faec349 100644 --- a/src/llmcompressor/transformers/finetune/session_mixin.py +++ b/src/llmcompressor/transformers/finetune/session_mixin.py @@ -486,8 +486,9 @@ def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False): ) self.save_state() - if self.tokenizer is not None: - self.tokenizer.save_pretrained(output_dir) + processor = getattr(self, "processing_class", self.tokenizer) + if processor is not None: + processor.save_pretrained(output_dir) if not self.recipe: return diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index fdae1c0f1..1d7abe3c0 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -49,7 +49,7 @@ patch_tied_tensors_bug, ) from llmcompressor.transformers.sparsification.sparse_model import ( - get_shared_tokenizer_src, + get_shared_processor_src, ) from llmcompressor.transformers.utils.helpers import detect_last_checkpoint from llmcompressor.utils.fsdp.helpers import is_fsdp_model @@ -228,7 +228,7 @@ def initialize_model_from_path( def initialize_processor_from_path(model_args, model, teacher): processor_src = model_args.processor - processor_src = processor_src or get_shared_tokenizer_src(model, teacher) + processor_src = processor_src or get_shared_processor_src(model, teacher) processor = AutoProcessor.from_pretrained( processor_src, cache_dir=model_args.cache_dir, diff --git a/src/llmcompressor/transformers/sparsification/sparse_model.py b/src/llmcompressor/transformers/sparsification/sparse_model.py index bf09396d7..d7abc323a 100644 --- a/src/llmcompressor/transformers/sparsification/sparse_model.py +++ b/src/llmcompressor/transformers/sparsification/sparse_model.py @@ -7,7 +7,7 @@ __all__ = [ "SparseAutoModelForCausalLM", - "get_shared_tokenizer_src", + "get_shared_processor_src", ] @@ -20,14 +20,14 @@ def from_pretrained(*args, **kwargs): return AutoModelForCausalLM.from_pretrained(*args, **kwargs) -def get_shared_tokenizer_src(student: Module, teacher: Optional[Module]) -> str: +def get_shared_processor_src(student: Module, teacher: Optional[Module]) -> str: """ - Get a tokenizer source used for both student and teacher, assuming + Get a processor/tokenizer source used for both student and teacher, assuming that they could be shared :param student: the student model :param teacher: the teacher model - :return: the source for the tokenizer shared between teacher and model + :return: the source for the processor/tokenizer shared between teacher and model """ if teacher is not None and teacher not in ("disable", "self"): diff --git a/src/llmcompressor/transformers/utils/helpers.py b/src/llmcompressor/transformers/utils/helpers.py index 401a454cf..b74a55add 100644 --- a/src/llmcompressor/transformers/utils/helpers.py +++ b/src/llmcompressor/transformers/utils/helpers.py @@ -37,6 +37,7 @@ ] +# TODO: update these files to reflect processors class TaskNames(Enum): mlm = {"masked-language-modeling", "mlm"} qa = {"question-answering", "qa"} diff --git a/src/llmcompressor/utils/typing.py b/src/llmcompressor/utils/typing.py index fa311ecb7..2f4bf4fc8 100644 --- a/src/llmcompressor/utils/typing.py +++ b/src/llmcompressor/utils/typing.py @@ -1,6 +1,13 @@ -from typing import Any, Union +from typing import Union -from transformers import PreTrainedTokenizer +from transformers import ( + BaseImageProcessor, + FeatureExtractionMixin, + PreTrainedTokenizer, + ProcessorMixin, +) # Tokenizer or Processor. Processors do not inherit from a unified base class -Processor = Union[PreTrainedTokenizer, Any] +Processor = Union[ + PreTrainedTokenizer, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin +] From 4461a3edede5b09934aa6b778f459e8588b020c8 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 2 Dec 2024 23:02:09 +0000 Subject: [PATCH 116/285] fix typehinting Signed-off-by: Kyle Sayers --- .../transformers/finetune/data/c4.py | 7 +++++-- .../finetune/data/cnn_dailymail.py | 7 +++++-- .../transformers/finetune/data/custom.py | 13 ------------ .../finetune/data/evolcodealpaca.py | 7 +++++-- .../transformers/finetune/data/gsm8k.py | 7 +++++-- .../finetune/data/open_platypus.py | 20 +++++-------------- .../transformers/finetune/data/ptb.py | 12 +++++------ .../finetune/data/ultrachat_200k.py | 12 +++++------ .../transformers/finetune/data/wikitext.py | 7 +++++-- 9 files changed, 40 insertions(+), 52 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/c4.py b/src/llmcompressor/transformers/finetune/data/c4.py index ab890ff86..7f02c0eaf 100644 --- a/src/llmcompressor/transformers/finetune/data/c4.py +++ b/src/llmcompressor/transformers/finetune/data/c4.py @@ -1,9 +1,12 @@ from copy import deepcopy +from typing import TYPE_CHECKING -from llmcompressor.transformers.finetune.data import DataTrainingArguments as DataArgs from llmcompressor.transformers.finetune.data import TextGenerationDataset from llmcompressor.utils import Processor +if TYPE_CHECKING: + from llmcompressor.transformers import DataTrainingArguments as DataArgs + @TextGenerationDataset.register(name="c4") class C4Dataset(TextGenerationDataset): @@ -15,7 +18,7 @@ class C4Dataset(TextGenerationDataset): :param processor: processor or tokenizer to use on dataset """ - def __init__(self, data_args: DataArgs, split: str, processor: Processor): + def __init__(self, data_args: "DataArgs", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "allenai/c4" data_args.text_column = "text" diff --git a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py index 39757cd76..b4e6484a0 100644 --- a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py +++ b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py @@ -1,9 +1,12 @@ from copy import deepcopy +from typing import TYPE_CHECKING -from llmcompressor.transformers.finetune.data import DataTrainingArguments as DataArgs from llmcompressor.transformers.finetune.data import TextGenerationDataset from llmcompressor.utils import Processor +if TYPE_CHECKING: + from llmcompressor.transformers import DataTrainingArguments as DataArgs + @TextGenerationDataset.register(name="cnn_dailymail") class CNNDailyMailDataset(TextGenerationDataset): @@ -17,7 +20,7 @@ class CNNDailyMailDataset(TextGenerationDataset): SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n" - def __init__(self, data_args: DataArgs, split: str, processor: Processor): + def __init__(self, data_args: "DataArgs", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "cnn_dailymail" data_args.dataset_config_name = "3.0.0" diff --git a/src/llmcompressor/transformers/finetune/data/custom.py b/src/llmcompressor/transformers/finetune/data/custom.py index 175d13468..84e721ba6 100644 --- a/src/llmcompressor/transformers/finetune/data/custom.py +++ b/src/llmcompressor/transformers/finetune/data/custom.py @@ -1,16 +1,3 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. from llmcompressor.transformers.finetune.data import TextGenerationDataset diff --git a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py index ce25858e5..784cd55fe 100644 --- a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py +++ b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py @@ -1,9 +1,12 @@ from copy import deepcopy +from typing import TYPE_CHECKING -from llmcompressor.transformers.finetune.data import DataTrainingArguments as DataArgs from llmcompressor.transformers.finetune.data import TextGenerationDataset from llmcompressor.utils import Processor +if TYPE_CHECKING: + from llmcompressor.transformers import DataTrainingArguments as DataArgs + @TextGenerationDataset.register(name="evolcodealpaca") class EvolCodeAlpacaDataset(TextGenerationDataset): @@ -22,7 +25,7 @@ class EvolCodeAlpacaDataset(TextGenerationDataset): "\n\n### Response:\n" ) - def __init__(self, data_args: DataArgs, split: str, processor: Processor): + def __init__(self, data_args: "DataArgs", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "theblackcat102/evol-codealpaca-v1" data_args.text_column = "text" diff --git a/src/llmcompressor/transformers/finetune/data/gsm8k.py b/src/llmcompressor/transformers/finetune/data/gsm8k.py index 60e4d8da4..36324fa87 100644 --- a/src/llmcompressor/transformers/finetune/data/gsm8k.py +++ b/src/llmcompressor/transformers/finetune/data/gsm8k.py @@ -1,9 +1,12 @@ from copy import deepcopy +from typing import TYPE_CHECKING -from llmcompressor.transformers.finetune.data import DataTrainingArguments as DataArgs from llmcompressor.transformers.finetune.data import TextGenerationDataset from llmcompressor.utils import Processor +if TYPE_CHECKING: + from llmcompressor.transformers import DataTrainingArguments as DataArgs + @TextGenerationDataset.register(name="gsm8k") class GSM8KDataset(TextGenerationDataset): @@ -17,7 +20,7 @@ class GSM8KDataset(TextGenerationDataset): GSM_TEMPLATE = "Question: {question}\nAnswer:" - def __init__(self, data_args: DataArgs, split: str, processor: Processor): + def __init__(self, data_args: "DataArgs", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "gsm8k" data_args.text_column = "text" diff --git a/src/llmcompressor/transformers/finetune/data/open_platypus.py b/src/llmcompressor/transformers/finetune/data/open_platypus.py index 6eb78a592..b4a136ba0 100644 --- a/src/llmcompressor/transformers/finetune/data/open_platypus.py +++ b/src/llmcompressor/transformers/finetune/data/open_platypus.py @@ -1,22 +1,12 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. from copy import deepcopy +from typing import TYPE_CHECKING -from llmcompressor.transformers.finetune.data import DataTrainingArguments as DataArgs from llmcompressor.transformers.finetune.data import TextGenerationDataset from llmcompressor.utils import Processor +if TYPE_CHECKING: + from llmcompressor.transformers import DataTrainingArguments as DataArgs + @TextGenerationDataset.register(name="open_platypus") class OpenPlatypusDataset(TextGenerationDataset): @@ -38,7 +28,7 @@ class OpenPlatypusDataset(TextGenerationDataset): "instruction}\n\n### Response:\n", } - def __init__(self, data_args: DataArgs, split: str, processor: Processor): + def __init__(self, data_args: "DataArgs", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "garage-bAInd/Open-Platypus" data_args.text_column = "text" diff --git a/src/llmcompressor/transformers/finetune/data/ptb.py b/src/llmcompressor/transformers/finetune/data/ptb.py index 78a6d865b..6b8bd3b61 100644 --- a/src/llmcompressor/transformers/finetune/data/ptb.py +++ b/src/llmcompressor/transformers/finetune/data/ptb.py @@ -1,9 +1,12 @@ from copy import deepcopy +from typing import TYPE_CHECKING from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments from llmcompressor.utils import Processor +if TYPE_CHECKING: + from llmcompressor.transformers import DataTrainingArguments as DataArgs + @TextGenerationDataset.register(name="ptb") class PtbDataset(TextGenerationDataset): @@ -15,12 +18,7 @@ class PtbDataset(TextGenerationDataset): :param processor: processor or tokenizer to use on dataset """ - def __init__( - self, - data_args: DataTrainingArguments, - split: str, - processor: Processor, - ): + def __init__(self, data_args: "DataArgs", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "ptb_text_only" data_args.text_column = "sentence" diff --git a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py index 1b7fc3f0f..c9d60bc84 100644 --- a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py +++ b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py @@ -1,9 +1,12 @@ from copy import deepcopy +from typing import TYPE_CHECKING from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments from llmcompressor.utils import Processor +if TYPE_CHECKING: + from llmcompressor.transformers import DataTrainingArguments as DataArgs + @TextGenerationDataset.register(name="ultrachat_200k") class UltraChatDataset(TextGenerationDataset): @@ -28,12 +31,7 @@ class UltraChatDataset(TextGenerationDataset): "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" ) - def __init__( - self, - data_args: DataTrainingArguments, - split: str, - processor: Processor, - ): + def __init__(self, data_args: "DataArgs", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "HuggingFaceH4/ultrachat_200k" data_args.text_column = "messages" diff --git a/src/llmcompressor/transformers/finetune/data/wikitext.py b/src/llmcompressor/transformers/finetune/data/wikitext.py index 4640fb90c..7ea6953f6 100644 --- a/src/llmcompressor/transformers/finetune/data/wikitext.py +++ b/src/llmcompressor/transformers/finetune/data/wikitext.py @@ -1,9 +1,12 @@ from copy import deepcopy +from typing import TYPE_CHECKING -from llmcompressor.transformers.finetune.data import DataTrainingArguments as DataArgs from llmcompressor.transformers.finetune.data import TextGenerationDataset from llmcompressor.utils import Processor +if TYPE_CHECKING: + from llmcompressor.transformers import DataTrainingArguments as DataArgs + @TextGenerationDataset.register(name="wikitext") class WikiTextDataset(TextGenerationDataset): @@ -15,7 +18,7 @@ class WikiTextDataset(TextGenerationDataset): :param processor: processor or tokenizer to use on dataset """ - def __init__(self, data_args: DataArgs, split: str, processor: Processor): + def __init__(self, data_args: "DataArgs", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "Salesforce/wikitext" data_args.text_column = "text" From fb330014082a41ea0c44b40b031cc047a824dce2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 2 Dec 2024 23:06:35 +0000 Subject: [PATCH 117/285] add typechecking imports Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/utils/helpers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/transformers/utils/helpers.py b/src/llmcompressor/transformers/utils/helpers.py index b74a55add..5efed9447 100644 --- a/src/llmcompressor/transformers/utils/helpers.py +++ b/src/llmcompressor/transformers/utils/helpers.py @@ -9,7 +9,7 @@ from contextlib import suppress from enum import Enum from pathlib import Path -from typing import Iterable, List, Optional +from typing import TYPE_CHECKING, Iterable, List, Optional from typing import OrderedDict as OrderedDictType from typing import Tuple, Union @@ -24,6 +24,10 @@ from llmcompressor.utils.fsdp.context import main_process_first_context +if TYPE_CHECKING: + from llmcompressor.transformers import ModelArguments, TrainingArguments + + __all__ = [ "RECIPE_NAME", "detect_last_checkpoint", From bf4744a37c25eb8dd75e091b94ab3ef794416731 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 3 Dec 2024 00:16:23 +0000 Subject: [PATCH 118/285] remove sparseml utilities Signed-off-by: Kyle Sayers --- .../compressed_tensors_utils.py | 3 +- .../transformers/utils/helpers.py | 446 +----------------- 2 files changed, 7 insertions(+), 442 deletions(-) diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index 6de89dd8b..88822f69e 100644 --- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -24,6 +24,7 @@ from llmcompressor.transformers.compression.sparsity_config import ( SparsityConfigMetadata, ) +from llmcompressor.transformers.utils import DEFAULT_RECIPE_NAME from llmcompressor.utils.fsdp.helpers import ( find_and_move_state_dicts_to_cpu, unwrap_and_export_model, @@ -189,7 +190,7 @@ def skip(*args, **kwargs): ) compressor.update_config(save_directory) - recipe_path = os.path.join(save_directory, "recipe.yaml") + recipe_path = os.path.join(save_directory, DEFAULT_RECIPE_NAME) session = active_session() if (recipe_yaml_str := session.get_serialized_recipe()) is not None: diff --git a/src/llmcompressor/transformers/utils/helpers.py b/src/llmcompressor/transformers/utils/helpers.py index 401a454cf..b53705b9b 100644 --- a/src/llmcompressor/transformers/utils/helpers.py +++ b/src/llmcompressor/transformers/utils/helpers.py @@ -3,75 +3,21 @@ huggingface/transformers flows """ -import inspect import os -from collections import OrderedDict -from contextlib import suppress -from enum import Enum -from pathlib import Path -from typing import Iterable, List, Optional -from typing import OrderedDict as OrderedDictType -from typing import Tuple, Union +from typing import TYPE_CHECKING, Optional -import requests -import torch -import transformers -from huggingface_hub import HUGGINGFACE_CO_URL_HOME, HfFileSystem, hf_hub_download from loguru import logger -from transformers import AutoConfig from transformers.trainer_utils import get_last_checkpoint -from transformers.utils import PaddingStrategy -from llmcompressor.utils.fsdp.context import main_process_first_context +if TYPE_CHECKING: + from llmcompressor.transformers import ModelArguments, TrainingArguments __all__ = [ - "RECIPE_NAME", + "DEFAULT_RECIPE_NAME", "detect_last_checkpoint", - "TaskNames", - "resolve_sequence_length", - "ALL_TASK_NAMES", - "create_fake_dataloader", - "POSSIBLE_TOKENIZER_FILES", - "download_repo_from_huggingface_hub", - "download_model_directory", ] - -class TaskNames(Enum): - mlm = {"masked-language-modeling", "mlm"} - qa = {"question-answering", "qa"} - token_classification = {"token-classification", "ner"} - text_classification = { - "text-classification", - "sentiment-analysis", - "sequence-classification", - "glue", - } - text_generation = {"text-generation"} - - -ALL_TASK_NAMES = list(set.union(*[task_names.value for task_names in TaskNames])) -RECIPE_NAME = "recipe.yaml" - -MANDATORY_DEPLOYMENT_FILES = { - "tokenizer_config.json", - "config.json", -} -OPTIONAL_DEPLOYMENT_FILES = {"tokenizer.json", "tokenizer.model"} -NLG_MANDATORY_DEPLOYMENT_FILES = {"special_tokens_map.json"} -NLG_OPTIONAL_DEPLOYMENT_FILES = { - "vocab.json", - "merges.txt", -} -POSSIBLE_TOKENIZER_FILES = { - "vocab.json", - "merges.txt", - "tokenizer.json", - "tokenizer.model", - "special_tokens_map.json", - "tokenizer_config.json", -} -RELEVANT_HF_SUFFIXES = ["json", "md", "bin", "safetensors", "yaml", "yml", "py"] +DEFAULT_RECIPE_NAME = "recipe.yaml" def detect_last_checkpoint( @@ -108,385 +54,3 @@ def detect_last_checkpoint( ) return last_checkpoint - - -def resolve_sequence_length(config: AutoConfig) -> int: - """ - Resolve the sequence length from the config - - :param config: the config to resolve the sequence length from - :return: the sequence length - """ - if hasattr(config, "max_position_embeddings"): - sequence_length = config.max_position_embeddings - - elif hasattr(config, "max_seq_len"): - sequence_length = config.max_seq_len - else: - raise ValueError( - "Could not infer a default sequence length " - "from the HF transformers config. Please specify " - "the sequence length with --sequence_length" - ) - logger.debug( - f"Using default sequence length of {sequence_length} " - "(inferred from HF transformers config) " - ) - return sequence_length - - -def resolve_recipe( - model_path: Union[str, Path], - recipe: Union[str, Path, None] = None, -) -> Union[str, None]: - """ - Resolve the recipe to apply to the model. - :param recipe: the recipe to apply to the model. - It can be one of the following: - - None - This means that we are not either not applying - any recipe and allowing the model to potentially - infer the appropriate pre-existing recipe - from the model_path - - a path to the recipe file - This can be a string or Path object pointing - to a recipe file. If the specified recipe file - is different from the potential pre-existing - recipe for that model (stored in the model_path), - the function will raise an warning - - name of the recipe file (e.g. "recipe.yaml") - Recipe file name specific is assumed to be stored - in the model_path - - a string containing the recipe - Needs to adhere to the SparseML recipe format - - :param model_path: the path to the model to load. - It can be one of the following: - - a path to the model directory - - a path to the model file - - Hugging face model id - - :return: the resolved recipe - """ - - if recipe is None: - return infer_recipe_from_model_path(model_path) - - elif os.path.isfile(recipe): - # recipe is a path to a recipe file - return resolve_recipe_file(recipe, model_path) - - elif os.path.isfile(os.path.join(model_path, recipe)): - # recipe is a name of a recipe file - recipe = os.path.join(model_path, recipe) - return resolve_recipe_file(recipe, model_path) - - elif isinstance(recipe, str): - # recipe is a string containing the recipe - logger.debug( - "Applying the recipe string directly to the model, without " - "checking for a potential existing recipe in the model_path." - ) - return recipe - - logger.info( - "No recipe requested and no default recipe " - f"found in {model_path}. Skipping recipe resolution." - ) - return None - - -def infer_recipe_from_model_path(model_path: Union[str, Path]) -> Optional[str]: - """ - Infer the recipe from the model_path. - :param model_path: the path to the model to load. - It can be one of the following: - - a path to the model directory - - a path to the model file - - Hugging face model id - :return the path to the recipe file if found, None otherwise - """ - model_path = model_path.as_posix() if isinstance(model_path, Path) else model_path - - if os.path.isdir(model_path) or os.path.isfile(model_path): - # model_path is a local path to the model directory or model file - # attempting to find the recipe in the model_directory - model_path = ( - os.path.dirname(model_path) if os.path.isfile(model_path) else model_path - ) - recipe = os.path.join(model_path, RECIPE_NAME) - if os.path.isfile(recipe): - logger.info(f"Found recipe in the model_path: {recipe}") - return recipe - logger.debug(f"No recipe found in the model_path: {model_path}") - return None - - recipe = recipe_from_huggingface_model_id(model_path)[0] - - if recipe is None: - logger.info("Failed to infer the recipe from the model_path") - return recipe - - -def recipe_from_huggingface_model_id( - model_path: str, recipe_name: str = RECIPE_NAME -) -> Tuple[Optional[str], bool]: - """ - Attempts to download the recipe from the huggingface model id. - - :param model_path: Assumed to be the huggingface model id. - If it is not, this function will return None. - :param recipe_name: The name of the recipe file to download. - Defaults to RECIPE_NAME. - :return: tuple: - - the path to the recipe file if found, None otherwise - - True if model_path is a valid huggingface model id, False otherwise - """ - model_id = os.path.join(HUGGINGFACE_CO_URL_HOME, model_path) - request = requests.get(model_id) - if not request.status_code == 200: - logger.debug( - "model_path is not a valid huggingface model id. " - "Skipping recipe resolution." - ) - return None, False - - logger.info( - "model_path is a huggingface model id. " - "Attempting to download recipe from " - f"{HUGGINGFACE_CO_URL_HOME}" - ) - try: - recipe = hf_hub_download(repo_id=model_path, filename=recipe_name) - logger.info(f"Found recipe: {recipe_name} for model id: {model_path}.") - except Exception as e: - logger.info( - f"Unable to to find recipe {recipe_name} " - f"for model id: {model_path}: {e}. " - "Skipping recipe resolution." - ) - recipe = None - return recipe, True - - -def resolve_recipe_file( - requested_recipe: Union[str, Path], model_path: Union[str, Path] -) -> Union[str, Path, None]: - """ - Given the requested recipe and the model_path, return the path to the recipe file. - - :param requested_recipe. Is a full path to the recipe file - :param model_path: the path to the model to load. - It can be one of the following: - - a path to the model directory - - a path to the model file - - Hugging face model id - :return the path to the recipe file if found, None otherwise - """ - # preprocess arguments so that they are all strings - requested_recipe = ( - requested_recipe.as_posix() - if isinstance(requested_recipe, Path) - else requested_recipe - ) - model_path = model_path.as_posix() if isinstance(model_path, Path) else model_path - model_path = ( - os.path.dirname(model_path) if os.path.isfile(model_path) else model_path - ) - - if not os.path.isdir(model_path): - default_recipe, model_exists = recipe_from_huggingface_model_id(model_path) - if not model_exists: - raise ValueError(f"Unrecognized model_path: {model_path}") - - if not default_recipe == requested_recipe and default_recipe is not None: - logger.warning( - f"Attempting to apply recipe: {requested_recipe} " - f"to the model at: {model_path}, " - f"but the model already has a recipe: {default_recipe}. " - f"Using {requested_recipe} instead." - ) - return requested_recipe - - # pathway for model_path that is a directory - default_recipe = os.path.join(model_path, RECIPE_NAME) - default_recipe_exists = os.path.isfile(default_recipe) - default_and_request_recipes_identical = os.path.samefile( - default_recipe, requested_recipe - ) - - if ( - default_recipe_exists - and requested_recipe - and not default_and_request_recipes_identical - ): - logger.warning( - f"Attempting to apply recipe: {requested_recipe} " - f"to the model located in {model_path}, " - f"but the model already has a recipe stored as {default_recipe}. " - f"Using {requested_recipe} instead." - ) - - elif not default_recipe_exists and requested_recipe: - logger.warning( - f"Attempting to apply {requested_recipe} " - f"to the model located in {model_path}." - "However, it is expected that the model " - f"has its target recipe stored as {default_recipe}." - "Applying any recipe before the target recipe may " - "result in unexpected behavior." - f"Applying {requested_recipe} nevertheless." - ) - - elif default_recipe_exists: - logger.info(f"Using the default recipe: {requested_recipe}") - - return requested_recipe - - -def create_fake_dataloader( - model: torch.nn.Module, - tokenizer: transformers.AutoTokenizer, - num_samples: int, -) -> Tuple[Iterable[OrderedDictType[str, torch.Tensor]], List[str]]: - """ - Creates fake transformers dataloader for the model, based on the model's - forward signature. - - :param model: The model to create the dataloader for - :param tokenizer: The tokenizer to use for the dataloader - :param num_samples: The number of fake samples in the dataloader - :return: The data loader (iterable) and the input names for the model - """ - - forward_args_spec = inspect.getfullargspec(model.__class__.forward) - inputs = tokenizer( - "", return_tensors="pt", padding=PaddingStrategy.MAX_LENGTH.value - ).data - fake_inputs = OrderedDict( - [ - (input_key, inputs[input_key][0].reshape(1, -1)) - for input_key in forward_args_spec.args - if input_key in inputs - ] - ) - data_loader = (fake_inputs for _ in range(num_samples)) - input_names = list(fake_inputs.keys()) - return data_loader, input_names - - -def fetch_recipe_path(target: str): - """ - Fetches the recipe path for the given target. - This method will also download the recipe if it is not - already downloaded. - - Takes care of three scenarios: - 1. target is a local path to a model directory - (looks for recipe.yaml in the directory) - 2. target is a HuggingFace stub (downloads and - returns the path to the default recipe) - - :param target: The target to fetch the recipe path for - can be a local path or HuggingFace stub - :return: The path to the recipe for the target - """ - DEFAULT_RECIPE_NAME = "recipe.yaml" - if Path(target).exists(): - # target is a local path - potential_recipe_path = Path(target) / DEFAULT_RECIPE_NAME - return str(potential_recipe_path) if potential_recipe_path.exists() else None - - # Recipe must be downloaded - - recipe_path = None - - # target is a HuggingFace stub - with suppress(Exception): - # suppress any errors if the recipe is not found on HuggingFace - recipe_path = hf_hub_download(repo_id=target, filename=DEFAULT_RECIPE_NAME) - - return recipe_path - - -def download_repo_from_huggingface_hub(repo_id, **kwargs): - """ - Download relevant model files from the Hugging Face Hub - using the huggingface_hub.hf_hub_download function - - Note(s): - - Does not download the entire repo, only the relevant files - for the model, such as the model weights, tokenizer files, etc. - - Does not re-download files that already exist locally, unless - the force_download flag is set to True - - :pre-condition: the repo_id must be a valid Hugging Face Hub repo id - :param repo_id: the repo id to download - :param kwargs: additional keyword arguments to pass to hf_hub_download - """ - hf_filesystem = HfFileSystem() - files = hf_filesystem.ls(repo_id) - - if not files: - raise ValueError(f"Could not find any files in HF repo {repo_id}") - - # All file(s) from hf_filesystem have "name" key - # Extract the file names from the files - relevant_file_names = ( - Path(file["name"]).name - for file in files - if any(file["name"].endswith(suffix) for suffix in RELEVANT_HF_SUFFIXES) - ) - - hub_kwargs_names = ( - "subfolder", - "repo_type", - "revision", - "library_name", - "library_version", - "cache_dir", - "local_dir", - "local_dir_use_symlinks", - "user_agent", - "force_download", - "force_filename", - "proxies", - "etag_timeout", - "resume_download", - "token", - "local_files_only", - "headers", - "legacy_cache_layout", - "endpoint", - ) - hub_kwargs = {name: kwargs[name] for name in hub_kwargs_names if name in kwargs} - - for file_name in relevant_file_names: - last_file = hf_hub_download(repo_id=repo_id, filename=file_name, **hub_kwargs) - - # parent directory of the last file is the model directory - return str(Path(last_file).parent.resolve().absolute()) - - -def download_model_directory(pretrained_model_name_or_path: str, **kwargs): - """ - Download the model directory from the HF hub if the model is not found locally - - :param pretrained_model_name_or_path: the name of or path to the model to load - can be a HuggingFace model stub - :param kwargs: additional keyword arguments to pass to the download function - :return: the path to the downloaded model directory - """ - pretrained_model_path: Path = Path(pretrained_model_name_or_path) - - if pretrained_model_path.exists(): - logger.debug( - "Model directory already exists locally.", - ) - return pretrained_model_name_or_path - - with main_process_first_context(): - logger.debug("Downloading model from HuggingFace Hub.") - return download_repo_from_huggingface_hub( - repo_id=pretrained_model_name_or_path, **kwargs - ) From 7e516c143fea72be5db0fa06e1d2d5bae6ea1cc4 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 3 Dec 2024 00:38:13 +0000 Subject: [PATCH 119/285] use in model_load Signed-off-by: Kyle Sayers --- src/llmcompressor/pytorch/model_load/helpers.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/pytorch/model_load/helpers.py b/src/llmcompressor/pytorch/model_load/helpers.py index 11a924f1d..180c559af 100644 --- a/src/llmcompressor/pytorch/model_load/helpers.py +++ b/src/llmcompressor/pytorch/model_load/helpers.py @@ -9,6 +9,7 @@ from llmcompressor.core import active_session, create_session, pre_initialize_structure from llmcompressor.pytorch.utils import ModuleSparsificationInfo +from llmcompressor.transformers import DEFAULT_RECIPE_NAME COMPLETED_STAGES_FILENAME = "completed_stages.json" @@ -24,8 +25,6 @@ "save_completed_stages", ] -RECIPE_FILE_NAME = "recipe.yaml" - def log_model_load( model: Module, model_name_or_path: str, model_type: str, delayed_load: bool @@ -116,7 +115,7 @@ def save_model_and_recipe( logger.info("Saving output to {}".format(os.path.abspath(save_path))) - recipe_path = os.path.join(save_path, RECIPE_FILE_NAME) + recipe_path = os.path.join(save_path, DEFAULT_RECIPE_NAME) session = active_session() recipe_yaml_str = session.get_serialized_recipe() with open(recipe_path, "w") as fp: From 9e33641b1ca660d61d1d857a4d4486152184e4df Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 3 Dec 2024 16:39:18 +0000 Subject: [PATCH 120/285] remove use of RECIPE FILE NAME Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/finetune/session_mixin.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py index e3a9c4d84..498ff4a40 100644 --- a/src/llmcompressor/transformers/finetune/session_mixin.py +++ b/src/llmcompressor/transformers/finetune/session_mixin.py @@ -24,8 +24,9 @@ from llmcompressor.modifiers.distillation.utils.pytorch.model_wrapper import ( KDModelWrapper, ) -from llmcompressor.pytorch.model_load.helpers import RECIPE_FILE_NAME, get_session_model +from llmcompressor.pytorch.model_load.helpers import get_session_model from llmcompressor.pytorch.utils import ModuleSparsificationInfo +from llmcompressor.transformers import DEFAULT_RECIPE_NAME from llmcompressor.transformers.finetune.callbacks import ( DisableHalfPrecisionCallback, TrainingLoopCallbacks, @@ -495,7 +496,7 @@ def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False): if self.accelerator.is_main_process: # save recipe, will contain modifiers from the model's original recipe as # well as those added from self.recipe - recipe_path = os.path.join(output_dir, RECIPE_FILE_NAME) + recipe_path = os.path.join(output_dir, DEFAULT_RECIPE_NAME) session = active_session() recipe_yaml_str = session.get_serialized_recipe() with open(recipe_path, "w") as fp: From 58c0fba3d75f8e25bd19fd6cbbc8823c2eaeb5c3 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 3 Dec 2024 16:56:00 +0000 Subject: [PATCH 121/285] rename to RECIPE_FILE_NAME, avoid circular import Signed-off-by: Kyle Sayers --- src/llmcompressor/pytorch/model_load/helpers.py | 5 +++-- src/llmcompressor/transformers/finetune/session_mixin.py | 4 ++-- .../transformers/sparsification/compressed_tensors_utils.py | 4 ++-- src/llmcompressor/transformers/utils/helpers.py | 4 ++-- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/llmcompressor/pytorch/model_load/helpers.py b/src/llmcompressor/pytorch/model_load/helpers.py index 180c559af..3db9be173 100644 --- a/src/llmcompressor/pytorch/model_load/helpers.py +++ b/src/llmcompressor/pytorch/model_load/helpers.py @@ -9,7 +9,6 @@ from llmcompressor.core import active_session, create_session, pre_initialize_structure from llmcompressor.pytorch.utils import ModuleSparsificationInfo -from llmcompressor.transformers import DEFAULT_RECIPE_NAME COMPLETED_STAGES_FILENAME = "completed_stages.json" @@ -105,6 +104,8 @@ def save_model_and_recipe( :param save_safetensors: whether to save as safetensors or pickle (bin) :param save_compressed: whether to compress sparse weights on disk """ + # avoid circular import + from llmcompressor.transformers.utils.helpers import RECIPE_FILE_NAME model.save_pretrained( save_path, save_compressed=save_compressed, safe_serialization=save_safetensors @@ -115,7 +116,7 @@ def save_model_and_recipe( logger.info("Saving output to {}".format(os.path.abspath(save_path))) - recipe_path = os.path.join(save_path, DEFAULT_RECIPE_NAME) + recipe_path = os.path.join(save_path, RECIPE_FILE_NAME) session = active_session() recipe_yaml_str = session.get_serialized_recipe() with open(recipe_path, "w") as fp: diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py index 498ff4a40..b1ac57b95 100644 --- a/src/llmcompressor/transformers/finetune/session_mixin.py +++ b/src/llmcompressor/transformers/finetune/session_mixin.py @@ -26,7 +26,7 @@ ) from llmcompressor.pytorch.model_load.helpers import get_session_model from llmcompressor.pytorch.utils import ModuleSparsificationInfo -from llmcompressor.transformers import DEFAULT_RECIPE_NAME +from llmcompressor.transformers import RECIPE_FILE_NAME from llmcompressor.transformers.finetune.callbacks import ( DisableHalfPrecisionCallback, TrainingLoopCallbacks, @@ -496,7 +496,7 @@ def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False): if self.accelerator.is_main_process: # save recipe, will contain modifiers from the model's original recipe as # well as those added from self.recipe - recipe_path = os.path.join(output_dir, DEFAULT_RECIPE_NAME) + recipe_path = os.path.join(output_dir, RECIPE_FILE_NAME) session = active_session() recipe_yaml_str = session.get_serialized_recipe() with open(recipe_path, "w") as fp: diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index 88822f69e..759098894 100644 --- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -24,7 +24,7 @@ from llmcompressor.transformers.compression.sparsity_config import ( SparsityConfigMetadata, ) -from llmcompressor.transformers.utils import DEFAULT_RECIPE_NAME +from llmcompressor.transformers.utils import RECIPE_FILE_NAME from llmcompressor.utils.fsdp.helpers import ( find_and_move_state_dicts_to_cpu, unwrap_and_export_model, @@ -190,7 +190,7 @@ def skip(*args, **kwargs): ) compressor.update_config(save_directory) - recipe_path = os.path.join(save_directory, DEFAULT_RECIPE_NAME) + recipe_path = os.path.join(save_directory, RECIPE_FILE_NAME) session = active_session() if (recipe_yaml_str := session.get_serialized_recipe()) is not None: diff --git a/src/llmcompressor/transformers/utils/helpers.py b/src/llmcompressor/transformers/utils/helpers.py index b53705b9b..a93111a8d 100644 --- a/src/llmcompressor/transformers/utils/helpers.py +++ b/src/llmcompressor/transformers/utils/helpers.py @@ -13,11 +13,11 @@ from llmcompressor.transformers import ModelArguments, TrainingArguments __all__ = [ - "DEFAULT_RECIPE_NAME", + "RECIPE_FILE_NAME", "detect_last_checkpoint", ] -DEFAULT_RECIPE_NAME = "recipe.yaml" +RECIPE_FILE_NAME = "recipe.yaml" def detect_last_checkpoint( From 8d13013f3e65352511753501271c0e3fcccae536 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 3 Dec 2024 20:02:59 +0000 Subject: [PATCH 122/285] image dataset collation --- qwen.py | 114 ++++++++++++++++++ .../transformers/finetune/data/__init__.py | 1 + .../finetune/data/data_helpers.py | 13 ++ .../transformers/finetune/data/flickr_30k.py | 61 ++++++++++ .../finetune/data/ultrachat_200k.py | 27 +++-- 5 files changed, 205 insertions(+), 11 deletions(-) create mode 100644 qwen.py create mode 100644 src/llmcompressor/transformers/finetune/data/flickr_30k.py diff --git a/qwen.py b/qwen.py new file mode 100644 index 000000000..7898c25d9 --- /dev/null +++ b/qwen.py @@ -0,0 +1,114 @@ +import torch +from datasets import load_dataset +from transformers import AutoProcessor, MllamaForConditionalGeneration, LlavaForConditionalGeneration, AutoModelForCausalLM, Qwen2VLForConditionalGeneration + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import oneshot +import os + +# Load model. +model_id = "Qwen/Qwen2-VL-2B-Instruct" +model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") +processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + +print("Loading dataset") +DATASET_ID = "lmms-lab/flickr30k" +DATASET_SPLIT = "test[:512]" + +NUM_CALIBRATION_SAMPLES = 1 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) +ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) + +print("Preprocessing samples") +def preprocess(example): + """ + Preprocesses a single example from the dataset. + """ + # Example messages structure + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What does the image show?"} + ] + } + ] + return { + "text": [processor.apply_chat_template( + messages, + add_generation_prompt=True, + )], + } + +ds = ds.map(preprocess, remove_columns=["caption", "sentids", "img_id", "filename"]) + + +# Tokenize inputs. +def tokenize(sample): + image = sample.pop("image") + return processor( + **sample, + images=[image], + add_special_tokens=False, + return_tensors="pt" + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +def collate_fn(batch): + assert len(batch) == 1 + return { + "input_ids": torch.LongTensor(batch[0]["input_ids"]), + "attention_mask": torch.tensor(batch[0]["attention_mask"]), + "pixel_values": torch.tensor(batch[0]["pixel_values"]), # torch.Size([14308, 1176]) + "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]), + } + + +from llmcompressor.pytorch.utils import tensors_to_device +from llmcompressor.transformers.finetune.data.data_helpers import format_calibration_data +one_sample = next(iter(format_calibration_data(ds, collate_fn=collate_fn))) +batch = tensors_to_device(one_sample, "cuda:0") +model(**batch) + +print("Setting up quantization params") +# Configure the quantization algorithm and scheme. +# In this case, we: +# * quantize the weights to fp8 with per channel via ptq +# * quantize the activations to fp8 with dynamic per token +#ignore=["re:.*lm_head", "re:model.vision_embed_tokens.*"] +#ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*", "re:language_model.*cross_attn.*"], +ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"] + +recipe = [ + # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), + GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore, update_size=NUM_CALIBRATION_SAMPLES), +] + +save_name = model_id.split("/")[1] + "-W8A8" +save_path = os.path.join("./my_test/", save_name) +print("Starting quantization") +oneshot( + model=model, + tokenizer=model_id, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, + output_dir=save_path, +) + +#processor.save_pretrained(save_path) + +# Confirm generations of the quantized model look sane. +print("========== SAMPLE GENERATION ==============") +input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=20) +print(processor.decode(output[0])) +print("==========================================") diff --git a/src/llmcompressor/transformers/finetune/data/__init__.py b/src/llmcompressor/transformers/finetune/data/__init__.py index a047c3b09..083ccf0e4 100644 --- a/src/llmcompressor/transformers/finetune/data/__init__.py +++ b/src/llmcompressor/transformers/finetune/data/__init__.py @@ -11,3 +11,4 @@ from .ptb import PtbDataset from .ultrachat_200k import UltraChatDataset from .wikitext import WikiTextDataset +from .flickr_30k import Flickr30K \ No newline at end of file diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index 1b0b589f9..fae0c543c 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -37,6 +37,19 @@ def format_calibration_data( :param accelerator: optional accelerator for if preparing in FSDP mode :return: list of trimmed calibration data tensors """ + if True: + def collate_fn(batch): + assert len(batch) == 1 + return { + "input_ids": torch.LongTensor(batch[0]["input_ids"]), + "attention_mask": torch.tensor(batch[0]["attention_mask"]), + "pixel_values": torch.tensor(batch[0]["pixel_values"]), # torch.Size([14308, 1176]) + "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]), + } + + + + safe_calibration_samples = len(tokenized_dataset) if num_calibration_samples is not None: safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples) diff --git a/src/llmcompressor/transformers/finetune/data/flickr_30k.py b/src/llmcompressor/transformers/finetune/data/flickr_30k.py new file mode 100644 index 000000000..84e45886d --- /dev/null +++ b/src/llmcompressor/transformers/finetune/data/flickr_30k.py @@ -0,0 +1,61 @@ +from copy import deepcopy +from typing import TYPE_CHECKING + +from llmcompressor.transformers.finetune.data import TextGenerationDataset +from llmcompressor.utils import Processor + +if TYPE_CHECKING: + from llmcompressor.transformers import DataTrainingArguments as DataArgs + + +@TextGenerationDataset.register(name="flickr", alias="flickr30k") +class Flickr30K(TextGenerationDataset): + """ + :param data_args: configuration settings for dataset loading + :param split: split from dataset to load, for instance `test` or `train[:5%]` + :param processor: processor or tokenizer to use on dataset + """ + + DEFAULT_CHAT_TEMPLATE = ( + "{% for message in messages %}\n" + "{% if message['role'] == 'user' %}\n" + "{{ '<|user|>\n' + message['content'] + eos_token }}\n" + "{% elif message['role'] == 'system' %}\n" + "{{ '<|system|>\n' + message['content'] + eos_token }}\n" + "{% elif message['role'] == 'assistant' %}\n" + "{{ '<|assistant|>\n' + message['content'] + eos_token }}\n" + "{% endif %}\n" + "{% if loop.last and add_generation_prompt %}\n" + "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" + ) + + def __init__(self, data_args: "DataArgs", split: str, processor: Processor): + data_args = deepcopy(data_args) + data_args.dataset = "lmms-lab/flickr30k" + + super().__init__(data_args=data_args, split=split, processor=processor) + + if ( + self.tokenizer is not None and + getattr(self.tokenizer, "chat_template", None) is None + ): + # note that since tokenizer is a member of processor, + # this change affects processor.apply_chat_template + self.tokenizer.chat_template = self.DEFAULT_CHAT_TEMPLATE + + def dataset_template(self, sample): + if self.processor is None: + raise ValueError("TODO") + + breakpoint() + + messages = sample["messages"] + if messages[0]["role"] != "system": + messages.insert(0, {"role": "system", "content": ""}) + + return { + "text": self.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ), + "image": sample["image"], + } \ No newline at end of file diff --git a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py index c9d60bc84..acc6498d8 100644 --- a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py +++ b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py @@ -3,6 +3,7 @@ from llmcompressor.transformers.finetune.data import TextGenerationDataset from llmcompressor.utils import Processor +from llmcompressor.utils.helpers import getattr_chain if TYPE_CHECKING: from llmcompressor.transformers import DataTrainingArguments as DataArgs @@ -39,22 +40,26 @@ def __init__(self, data_args: "DataArgs", split: str, processor: Processor): if split in ["train", "test"]: split += "_sft" - super().__init__( - data_args=data_args, - split=split, - processor=processor, - ) + super().__init__(data_args=data_args, split=split, processor=processor) - if getattr(self.tokenizer, "chat_template", None) is None: + if ( + self.tokenizer is not None and + getattr(self.tokenizer, "chat_template", None) is None + ): + # note that since tokenizer is a member of processor, + # this change affects processor.apply_chat_template self.tokenizer.chat_template = self.DEFAULT_CHAT_TEMPLATE def dataset_template(self, sample): - messages = sample["messages"] + if self.processor is None: + raise ValueError("TODO") + messages = sample["messages"] if messages[0]["role"] != "system": messages.insert(0, {"role": "system", "content": ""}) - text = self.tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=False - ) - return {"text": text} + return { + "text": self.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=False + ) + } From 163ee8fb4da1262ee2b4d2894b8d896e855466fe Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 3 Dec 2024 20:40:55 +0000 Subject: [PATCH 123/285] cleanup, do not handle case where processor is None Signed-off-by: Kyle Sayers --- .../transformers/finetune/data/__init__.py | 2 +- src/llmcompressor/transformers/finetune/data/base.py | 9 +-------- .../transformers/finetune/data/data_helpers.py | 10 +++++----- .../transformers/finetune/data/flickr_30k.py | 6 +++--- .../transformers/finetune/data/ultrachat_200k.py | 5 ++--- .../transformers/finetune/text_generation.py | 6 +++++- src/llmcompressor/transformers/utils/helpers.py | 3 --- 7 files changed, 17 insertions(+), 24 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/__init__.py b/src/llmcompressor/transformers/finetune/data/__init__.py index 083ccf0e4..ddf0b2364 100644 --- a/src/llmcompressor/transformers/finetune/data/__init__.py +++ b/src/llmcompressor/transformers/finetune/data/__init__.py @@ -6,9 +6,9 @@ from .custom import CustomDataset from .data_args import DataTrainingArguments from .evolcodealpaca import EvolCodeAlpacaDataset +from .flickr_30k import Flickr30K from .gsm8k import GSM8KDataset from .open_platypus import OpenPlatypusDataset from .ptb import PtbDataset from .ultrachat_200k import UltraChatDataset from .wikitext import WikiTextDataset -from .flickr_30k import Flickr30K \ No newline at end of file diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index b1d822eeb..b4b06b3b8 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -31,13 +31,6 @@ class TextGenerationDataset(RegistryMixin): # used to mask out the prompt so prompt tokens do not contribute to training loss PROMPT_KEY = "prompt" - # TODO: not sure how to handle the prompt stuff best. Specifically w.r.t. - """ - dataset = self.processor(**dataset) - - if dataset includes the PROMPT_KEY - """ - def __init__( self, data_args: DataTrainingArguments, @@ -102,7 +95,7 @@ def __call__(self, add_labels: bool = True) -> DatasetType: # rename and remove columns match processor kwargs dataset = self.rename_columns(dataset) - if self.processor is not None and "input_ids" not in dataset.column_names: + if "input_ids" not in dataset.column_names: # tokenize/ process dataset = self.map( dataset, diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index fae0c543c..d7dff5375 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -37,19 +37,19 @@ def format_calibration_data( :param accelerator: optional accelerator for if preparing in FSDP mode :return: list of trimmed calibration data tensors """ - if True: + if False: + def collate_fn(batch): assert len(batch) == 1 return { "input_ids": torch.LongTensor(batch[0]["input_ids"]), "attention_mask": torch.tensor(batch[0]["attention_mask"]), - "pixel_values": torch.tensor(batch[0]["pixel_values"]), # torch.Size([14308, 1176]) + "pixel_values": torch.tensor( + batch[0]["pixel_values"] + ), # torch.Size([14308, 1176]) "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]), } - - - safe_calibration_samples = len(tokenized_dataset) if num_calibration_samples is not None: safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples) diff --git a/src/llmcompressor/transformers/finetune/data/flickr_30k.py b/src/llmcompressor/transformers/finetune/data/flickr_30k.py index 84e45886d..8802ccd5a 100644 --- a/src/llmcompressor/transformers/finetune/data/flickr_30k.py +++ b/src/llmcompressor/transformers/finetune/data/flickr_30k.py @@ -36,8 +36,8 @@ def __init__(self, data_args: "DataArgs", split: str, processor: Processor): super().__init__(data_args=data_args, split=split, processor=processor) if ( - self.tokenizer is not None and - getattr(self.tokenizer, "chat_template", None) is None + self.tokenizer is not None + and getattr(self.tokenizer, "chat_template", None) is None ): # note that since tokenizer is a member of processor, # this change affects processor.apply_chat_template @@ -58,4 +58,4 @@ def dataset_template(self, sample): messages, tokenize=False, add_generation_prompt=True ), "image": sample["image"], - } \ No newline at end of file + } diff --git a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py index acc6498d8..3e5a30b7b 100644 --- a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py +++ b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py @@ -3,7 +3,6 @@ from llmcompressor.transformers.finetune.data import TextGenerationDataset from llmcompressor.utils import Processor -from llmcompressor.utils.helpers import getattr_chain if TYPE_CHECKING: from llmcompressor.transformers import DataTrainingArguments as DataArgs @@ -43,8 +42,8 @@ def __init__(self, data_args: "DataArgs", split: str, processor: Processor): super().__init__(data_args=data_args, split=split, processor=processor) if ( - self.tokenizer is not None and - getattr(self.tokenizer, "chat_template", None) is None + self.tokenizer is not None + and getattr(self.tokenizer, "chat_template", None) is None ): # note that since tokenizer is a member of processor, # this change affects processor.apply_chat_template diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index 1d7abe3c0..bbddc53e5 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -27,6 +27,7 @@ AutoProcessor, DefaultDataCollator, HfArgumentParser, + PreTrainedModel, set_seed, ) @@ -52,6 +53,7 @@ get_shared_processor_src, ) from llmcompressor.transformers.utils.helpers import detect_last_checkpoint +from llmcompressor.utils import Processor from llmcompressor.utils.fsdp.helpers import is_fsdp_model @@ -226,7 +228,9 @@ def initialize_model_from_path( return teacher, model_path, model -def initialize_processor_from_path(model_args, model, teacher): +def initialize_processor_from_path( + model_args: ModelArguments, model: PreTrainedModel, teacher: PreTrainedModel +) -> Processor: processor_src = model_args.processor processor_src = processor_src or get_shared_processor_src(model, teacher) processor = AutoProcessor.from_pretrained( diff --git a/src/llmcompressor/transformers/utils/helpers.py b/src/llmcompressor/transformers/utils/helpers.py index cb3906812..e2ad2bef1 100644 --- a/src/llmcompressor/transformers/utils/helpers.py +++ b/src/llmcompressor/transformers/utils/helpers.py @@ -12,9 +12,6 @@ if TYPE_CHECKING: from llmcompressor.transformers import ModelArguments, TrainingArguments -if TYPE_CHECKING: - from llmcompressor.transformers import ModelArguments, TrainingArguments - __all__ = [ "RECIPE_FILE_NAME", From 1180b3417c4884de05b297fa8f5e258c540eebef Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 3 Dec 2024 20:42:11 +0000 Subject: [PATCH 124/285] remove qa ignore Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/utils/helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/transformers/utils/helpers.py b/src/llmcompressor/transformers/utils/helpers.py index a93111a8d..1263bb004 100644 --- a/src/llmcompressor/transformers/utils/helpers.py +++ b/src/llmcompressor/transformers/utils/helpers.py @@ -21,8 +21,8 @@ def detect_last_checkpoint( - training_args: "TrainingArguments", # noqa 821 - model_args: Optional["ModelArguments"] = None, # noqa 821 + training_args: "TrainingArguments", + model_args: Optional["ModelArguments"] = None, ): last_checkpoint = None if ( From c431958955696ee3234ce35c865724b3b75ea017 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 3 Dec 2024 20:56:41 +0000 Subject: [PATCH 125/285] add documentation Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/finetune/data/base.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index b4b06b3b8..11b7bbfa9 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -21,7 +21,13 @@ class TextGenerationDataset(RegistryMixin): """ - Base class for text datasets, handles tokenization and dataset splits + Base class for text datasets. Applies the following transformations to a dataset + in order to prepare the dataset to be loaded by a dataloader + + 1. Load dataset from huggingface or local cache + 2. Preprocess dataset according to preprocess function or chat/dataset template + 3. Tokenize dataset using model tokenizer/processor + 4. Apply post processing such as grouping text and/or adding labels for finetuning :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` From b48d55d452ea228155f5d9727acb374ea09d174e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 3 Dec 2024 21:31:03 +0000 Subject: [PATCH 126/285] add data collator arg Signed-off-by: Kyle Sayers --- qwen.py | 10 ++-------- .../transformers/finetune/data/data_args.py | 12 +++++++++++- .../transformers/finetune/data/data_helpers.py | 13 ------------- src/llmcompressor/transformers/finetune/runner.py | 1 + .../transformers/finetune/text_generation.py | 4 +--- 5 files changed, 15 insertions(+), 25 deletions(-) diff --git a/qwen.py b/qwen.py index 7898c25d9..e087e9266 100644 --- a/qwen.py +++ b/qwen.py @@ -60,7 +60,7 @@ def tokenize(sample): ds = ds.map(tokenize, remove_columns=ds.column_names) -def collate_fn(batch): +def data_collator(batch): assert len(batch) == 1 return { "input_ids": torch.LongTensor(batch[0]["input_ids"]), @@ -69,13 +69,6 @@ def collate_fn(batch): "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]), } - -from llmcompressor.pytorch.utils import tensors_to_device -from llmcompressor.transformers.finetune.data.data_helpers import format_calibration_data -one_sample = next(iter(format_calibration_data(ds, collate_fn=collate_fn))) -batch = tensors_to_device(one_sample, "cuda:0") -model(**batch) - print("Setting up quantization params") # Configure the quantization algorithm and scheme. # In this case, we: @@ -102,6 +95,7 @@ def collate_fn(batch): num_calibration_samples=NUM_CALIBRATION_SAMPLES, trust_remote_code_model=True, output_dir=save_path, + data_collator=data_collator, ) #processor.save_pretrained(save_path) diff --git a/src/llmcompressor/transformers/finetune/data/data_args.py b/src/llmcompressor/transformers/finetune/data/data_args.py index 9623d413a..06aa685aa 100644 --- a/src/llmcompressor/transformers/finetune/data/data_args.py +++ b/src/llmcompressor/transformers/finetune/data/data_args.py @@ -1,5 +1,7 @@ from dataclasses import dataclass, field -from typing import Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union + +from transformers import DefaultDataCollator @dataclass @@ -54,6 +56,14 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments): }, ) + data_collator: Callable[[Any], Any] = field( + default=DefaultDataCollator(), + metadata={ + "help": "For custom datasets only. The function to used to form a batch " + "from the dataset" + }, + ) + @dataclass class DataTrainingArguments(CustomDataTrainingArguments): diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index d7dff5375..1b0b589f9 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -37,19 +37,6 @@ def format_calibration_data( :param accelerator: optional accelerator for if preparing in FSDP mode :return: list of trimmed calibration data tensors """ - if False: - - def collate_fn(batch): - assert len(batch) == 1 - return { - "input_ids": torch.LongTensor(batch[0]["input_ids"]), - "attention_mask": torch.tensor(batch[0]["attention_mask"]), - "pixel_values": torch.tensor( - batch[0]["pixel_values"] - ), # torch.Size([14308, 1176]) - "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]), - } - safe_calibration_samples = len(tokenized_dataset) if num_calibration_samples is not None: safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples) diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py index 6a89e8a7a..368da499a 100644 --- a/src/llmcompressor/transformers/finetune/runner.py +++ b/src/llmcompressor/transformers/finetune/runner.py @@ -138,6 +138,7 @@ def one_shot(self, stage: Optional[str] = None): tokenized_dataset=self.get_dataset_split("calibration"), num_calibration_samples=self._data_args.num_calibration_samples, do_shuffle=self._data_args.shuffle_calibration_samples, + collate_fn=self._data_args.data_collator, accelerator=self.trainer.accelerator, ) diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index bbddc53e5..ffea2df9d 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -25,7 +25,6 @@ AutoConfig, AutoModelForCausalLM, AutoProcessor, - DefaultDataCollator, HfArgumentParser, PreTrainedModel, set_seed, @@ -339,7 +338,6 @@ def main( calib_dataset = stage_runner.get_dataset_split("calibration") # Initialize our Trainer - data_collator = DefaultDataCollator() trainer = Trainer( model_init=get_session_model, teacher=teacher, @@ -350,7 +348,7 @@ def main( train_dataset=train_dataset or calib_dataset, eval_dataset=eval_dataset, processing_class=processor, - data_collator=data_collator, + data_collator=data_args.data_collator, ) # wrap model.save_pretrained From 0ed5c2c5c212890141942285dff914333bbb9054 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 3 Dec 2024 21:39:08 +0000 Subject: [PATCH 127/285] use default factor Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/finetune/data/data_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/transformers/finetune/data/data_args.py b/src/llmcompressor/transformers/finetune/data/data_args.py index 06aa685aa..c2340549f 100644 --- a/src/llmcompressor/transformers/finetune/data/data_args.py +++ b/src/llmcompressor/transformers/finetune/data/data_args.py @@ -57,7 +57,7 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments): ) data_collator: Callable[[Any], Any] = field( - default=DefaultDataCollator(), + default_factory=lambda: DefaultDataCollator(), metadata={ "help": "For custom datasets only. The function to used to form a batch " "from the dataset" From 41dd463edf77031167b0cf0c4a660658cf037e39 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 4 Dec 2024 17:51:26 +0000 Subject: [PATCH 128/285] wip mllama --- mllama.py | 116 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ qwen.py | 3 +- 2 files changed, 117 insertions(+), 2 deletions(-) create mode 100644 mllama.py diff --git a/mllama.py b/mllama.py new file mode 100644 index 000000000..dff0e28e8 --- /dev/null +++ b/mllama.py @@ -0,0 +1,116 @@ +import torch +from datasets import load_dataset +from transformers import AutoProcessor, MllamaForConditionalGeneration, LlavaForConditionalGeneration, AutoModelForCausalLM, Qwen2VLForConditionalGeneration + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import oneshot +import os + +# Load model. +model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" +#model = MllamaForConditionalGeneration.from_pretrained(model_id) +#model_id = "mgoin/pixtral-12b" +#model_id = "Qwen/Qwen2-VL-2B-Instruct" + +#model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="cuda:0", _attn_implementation="eager") +#model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") +#model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0", torch_dtype="auto") +model = MllamaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") +processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + +print("Loading dataset") +DATASET_ID = "lmms-lab/flickr30k" +DATASET_SPLIT = "test[:512]" + +NUM_CALIBRATION_SAMPLES = 1 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) +ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) + +print("Preprocessing samples") +def preprocess(example): + messages = [ + [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What does the image show?"} + ] + } + ], + ] + return { + "text": processor.apply_chat_template( + messages, + add_generation_prompt=True, + ), + } + +ds = ds.map(preprocess, remove_columns=["caption", "sentids", "img_id", "filename"]) + + +# Tokenize inputs. +def tokenize(sample): + image = sample.pop("image") + return processor( + **sample, + images=[image], + add_special_tokens=False, + return_tensors="pt" + ) + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +def data_collator(batch): + assert len(batch) == 1 + return { + "input_ids": torch.LongTensor(batch[0]["input_ids"]), + "attention_mask": torch.tensor(batch[0]["attention_mask"]), + "pixel_values": torch.tensor(batch[0]["pixel_values"]), + "aspect_ratio_ids": torch.tensor(batch[0]["aspect_ratio_ids"]), + "aspect_ratio_mask": torch.tensor(batch[0]["aspect_ratio_mask"]), + "cross_attention_mask": torch.tensor(batch[0]["cross_attention_mask"]), + } + + + +print("Setting up quantization params") +# Configure the quantization algorithm and scheme. +# In this case, we: +# * quantize the weights to fp8 with per channel via ptq +# * quantize the activations to fp8 with dynamic per token +#ignore=["re:.*lm_head", "re:model.vision_embed_tokens.*"] +#ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*", "re:language_model.*cross_attn.*"], +ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"] + +recipe = [ + # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), + GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore, update_size=NUM_CALIBRATION_SAMPLES), +] + +save_name = model_id.split("/")[1] + "-W8A8" +save_path = os.path.join("./my_test/", save_name) +print("Starting quantization") +oneshot( + model=model, + tokenizer=model_id, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, + output_dir=save_path, + data_collator=data_collator, +) + +#processor.save_pretrained(save_path) + +# Confirm generations of the quantized model look sane. +print("========== SAMPLE GENERATION ==============") +input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=20) +print(processor.decode(output[0])) +print("==========================================") diff --git a/qwen.py b/qwen.py index e087e9266..80f9fd2ee 100644 --- a/qwen.py +++ b/qwen.py @@ -57,7 +57,6 @@ def tokenize(sample): return_tensors="pt" ) - ds = ds.map(tokenize, remove_columns=ds.column_names) def data_collator(batch): @@ -65,7 +64,7 @@ def data_collator(batch): return { "input_ids": torch.LongTensor(batch[0]["input_ids"]), "attention_mask": torch.tensor(batch[0]["attention_mask"]), - "pixel_values": torch.tensor(batch[0]["pixel_values"]), # torch.Size([14308, 1176]) + "pixel_values": torch.tensor(batch[0]["pixel_values"]), "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]), } From 8527e0e757d22b05a2224e32f8c2662cafb4324e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 4 Dec 2024 17:55:52 +0000 Subject: [PATCH 129/285] cleanup --- mllama.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/mllama.py b/mllama.py index dff0e28e8..30e27e12f 100644 --- a/mllama.py +++ b/mllama.py @@ -8,13 +8,6 @@ # Load model. model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" -#model = MllamaForConditionalGeneration.from_pretrained(model_id) -#model_id = "mgoin/pixtral-12b" -#model_id = "Qwen/Qwen2-VL-2B-Instruct" - -#model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="cuda:0", _attn_implementation="eager") -#model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") -#model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0", torch_dtype="auto") model = MllamaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) @@ -78,12 +71,7 @@ def data_collator(batch): print("Setting up quantization params") -# Configure the quantization algorithm and scheme. -# In this case, we: -# * quantize the weights to fp8 with per channel via ptq -# * quantize the activations to fp8 with dynamic per token -#ignore=["re:.*lm_head", "re:model.vision_embed_tokens.*"] -#ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*", "re:language_model.*cross_attn.*"], +# CHANGE THIS IF YOU WANT TO QUANTIZE THE VISION TOWER ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"] recipe = [ From 0a8a03f7d730b910c4e0acc63b9e70f4513ef1ca Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 4 Dec 2024 18:23:56 +0000 Subject: [PATCH 130/285] merge-implement hessian offloading --- qwen.py | 2 +- .../modifiers/quantization/gptq/base.py | 48 ++++++++++++------- .../peicewise.py | 0 3 files changed, 32 insertions(+), 18 deletions(-) rename src/llmcompressor/{data_pipelines => pipelines}/peicewise.py (100%) diff --git a/qwen.py b/qwen.py index 80f9fd2ee..f65125d0e 100644 --- a/qwen.py +++ b/qwen.py @@ -79,7 +79,7 @@ def data_collator(batch): recipe = [ # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), - GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore, update_size=NUM_CALIBRATION_SAMPLES), + GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore, update_size=NUM_CALIBRATION_SAMPLES, offload_hessians=False), ] save_name = model_id.split("/")[1] + "-W8A8" diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 26434e040..b25fafb6b 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -1,3 +1,4 @@ +import contextlib import warnings from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union @@ -103,14 +104,16 @@ class GPTQModifier(Modifier, HooksMixin): and activation 8 bit quantization on the Linear layers. """ + # gptq modifier arguments sequential_update: bool = True # DEPRECIATED update_size: int = 1 sequential_targets: Union[str, List[str], None] = None block_size: int = 128 dampening_frac: Optional[float] = 0.01 quantize: Union[bool, Dict] = True + offload_hessians: bool = False - # arguments used for quant modifier + # arguments used for attached quant modifier config_groups: Optional[Dict[str, QuantizationScheme]] = None scheme: Optional[Union[str, Dict[str, Any]]] = None targets: Union[str, List[str], None] = None @@ -118,11 +121,10 @@ class GPTQModifier(Modifier, HooksMixin): num_calibration_steps: Optional[int] = None disable_quantization_observer_epoch: Optional[float] = None + # private variables _quantization_modifier: Optional[QuantizationModifier] = PrivateAttr() - _hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr( - default_factory=lambda: {} - ) - _num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=lambda: {}) + _hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict) + _num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict) @field_validator("sequential_update", mode="before") def validate_sequential_update(cls, value: bool) -> bool: @@ -250,26 +252,27 @@ def compress_module( :return: total loss from applying weight quantization to this module """ - # Assume that first argument is the input inp = args[0] quant_args = getattr_chain(module, "quantization_scheme.weights") - # TODO: attach as parameters to the module to allow them to be offloaded + # Initialize hessian if not present if module not in self._num_samples: self._hessians[module] = make_empty_hessian(module) self._num_samples[module] = 0 - self._hessians[module], self._num_samples[module] = accumulate_hessian( - inp, - type(module), - self._hessians[module], - self._num_samples[module], - ) + # Accumulate hessian with input with optional offloading + with self._maybe_offload_hessians(module): + self._hessians[module], self._num_samples[module] = accumulate_hessian( + inp, + type(module), + self._hessians[module], + self._num_samples[module], + ) + # After enough samples are accumulated, perform quantization if self._num_samples[module] >= self.update_size: - logger.info(f"Quantizing {name}...") - logger.info(f"Using {self._num_samples[module]} accumulated samples") + logger.info(f"Quantizing {name} using {self._num_samples[module]} samples") with align_module(module), CompressionLogger(module) as comp_logger: loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( module.weight.data, @@ -280,8 +283,7 @@ def compress_module( module_class=type(module), ) - module.weight -= module.weight - module.weight += quantized_weight + module.weight += quantized_weight - module.weight # Future: FSDP update_offload_parameter(module, "weight") update_offload_parameter(module, "weight_scale", scale) update_offload_parameter(module, "weight_zero_point", zero_point) @@ -293,6 +295,18 @@ def compress_module( comp_logger.set_loss(loss) + @contextlib.contextmanager + def _maybe_offload_hessians(self, module: torch.nn.Module): + if self.offload_hessians: + device = self._hessians[module].device + self._hessians[module] = self._hessians[module].to(device="cpu") + + yield + + if self.offload_hessians: + self._hessians[module] = self._hessians[module].to(device=device) + + def _build_quant_modifier(self): """ Build a quantization modifier based on the specified config_groups, diff --git a/src/llmcompressor/data_pipelines/peicewise.py b/src/llmcompressor/pipelines/peicewise.py similarity index 100% rename from src/llmcompressor/data_pipelines/peicewise.py rename to src/llmcompressor/pipelines/peicewise.py From fc044e217bb6150aff55270efb9b7f5f1a55245c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 4 Dec 2024 19:04:41 +0000 Subject: [PATCH 131/285] better concrete arg handling --- .../gptq/utils/partitioned_model.py | 70 +++++-------------- 1 file changed, 17 insertions(+), 53 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py index 4605f11e2..c81c34c97 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py @@ -168,58 +168,6 @@ def trace_consumed_names(subgraphs: List[Dict[str, Any]]): assert False -def make_fused_concrete_args(root: torch.nn.Module, dummy_inputs: Dict[str, Any]): - sig = inspect.signature(root.forward if isinstance(root, torch.nn.Module) else root) - - concrete_args = {} - - for param in sig.parameters.values(): - if param.name in dummy_inputs: - continue - if param.default is inspect.Parameter.empty: - raise ValueError( - f"You need to specify a default value for the parameter {param.name}." - ) - - concrete_args.update( - { - p.name: p.default - for p in sig.parameters.values() - if (p.name not in dummy_inputs and p.name not in concrete_args) - } - ) - concrete_args.update(dummy_inputs) - - return concrete_args - - -def make_placeholders( - tracer, model: torch.nn.Module, graph: GraphModule, dummy_inputs: Dict[str, Any] -): - # TODO: this dictionary does not match tensors which have been deep copied - # in general it's pretty annoying, since tracer.create_args_for_root basically - # converts kwargs to args and therefore gets rid of any of the names. - - # maybe instead of caching by kwargs, we cache by arg tuples? Not sure - - # Note, maybe relevant: tracer.create_args_for_root converts kwargs to args using the forward function signature - - # TODO: assumes that all inputs are tensors - for input_name, input_value in dummy_inputs.items(): - for tensor_value, name in tracer.tensor_attrs.items(): - if torch.allclose(input_value, tensor_value): - nodes = graph.graph.find_nodes(op="get_attr", target=name) - assert len(nodes) == 1 - node = nodes[0] - node.target = input_name - node.name = input_name - node.op = "placeholder" - break - - else: - breakpoint() - - class PartitionedModel: def __init__(self): self.graph = None @@ -244,8 +192,24 @@ def is_leaf_module( return super().is_leaf_module(module, module_qualified_name) with HooksMixin.disable_hooks(), calibration_forward_context(self.model): + sig = inspect.signature(self.model.forward) + concrete_args = {} + for parameter in sig.parameters.values(): + if parameter.name in model.dummy_inputs: + continue + if parameter.kind == inspect._ParameterKind.VAR_POSITIONAL: + value = list() + elif parameter.kind == inspect._ParameterKind.VAR_KEYWORD: + value = dict() + elif parameter.name == "use_cache": + value = False + else: + value = parameter.default + + concrete_args[parameter.name] = value + #self.graph: GraphModule = symbolic_trace(model, disable_check=True, tracer_cls=CustomTracer) - self.graph: GraphModule = torch.fx.GraphModule(model, CustomTracer().trace(model, dummy_inputs=model.dummy_inputs, concrete_args={"use_cache": False})) + self.graph: GraphModule = torch.fx.GraphModule(model, CustomTracer().trace(model, dummy_inputs=model.dummy_inputs, concrete_args=concrete_args, complete_concrete_args_with_inputs_not_in_dummy_inputs=False)) self.graph.config = model.config self.graph.class_for_deserialization = model.__class__ self.graph.device = model.device From 45767127a5f255e61c0be8cd36e45f8fafc44cec Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 4 Dec 2024 20:46:37 +0000 Subject: [PATCH 132/285] validate flickr Signed-off-by: Kyle Sayers --- examples/multimodal_vision/qwen_vl2.py | 89 +++++++++++++++ qwen.py | 108 ------------------ .../transformers/finetune/data/base.py | 11 +- .../finetune/data/data_helpers.py | 75 ++++++++---- .../transformers/finetune/data/flickr_30k.py | 26 +++-- src/llmcompressor/utils/typing.py | 4 + .../finetune/data/test_dataset_loading.py | 21 ++-- 7 files changed, 179 insertions(+), 155 deletions(-) create mode 100644 examples/multimodal_vision/qwen_vl2.py delete mode 100644 qwen.py diff --git a/examples/multimodal_vision/qwen_vl2.py b/examples/multimodal_vision/qwen_vl2.py new file mode 100644 index 000000000..c8cfb6e73 --- /dev/null +++ b/examples/multimodal_vision/qwen_vl2.py @@ -0,0 +1,89 @@ +import os + +import torch +from transformers import AutoProcessor, Qwen2VLForConditionalGeneration + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import oneshot + +# Load model. +model_id = "Qwen/Qwen2-VL-2B-Instruct" +model = Qwen2VLForConditionalGeneration.from_pretrained( + model_id, device_map="auto", torch_dtype="auto" +) +processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + +# oneshot arguments +DATASET_ID = "flickr30k" +DATASET_SPLIT = "test[:512]" +NUM_CALIBRATION_SAMPLES = 1 +MAX_SEQUENCE_LENGTH = 2048 + + +def data_collator(batch): + assert len(batch) == 1 + return { + "input_ids": torch.LongTensor(batch[0]["input_ids"]), + "attention_mask": torch.tensor(batch[0]["attention_mask"]), + "pixel_values": torch.tensor( + batch[0]["pixel_values"] + ), # torch.Size([14308, 1176]) + "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]), + } + + +ignore = ["re:.*lm_head"] + +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationScheme, + QuantizationStrategy, + QuantizationType, +) + +recipe = GPTQModifier( + targets="Linear", + config_groups={ + "config_group": QuantizationScheme( + targets=["Linear"], + weights=QuantizationArgs( + num_bits=4, + type=QuantizationType.INT, + strategy=QuantizationStrategy.GROUP, + group_size=128, + symmetric=True, + dynamic=False, + actorder="dynamic", + ), + ), + }, + ignore=ignore, + update_size=NUM_CALIBRATION_SAMPLES, + dampening_frac=0.5, +) + +save_name = model_id.split("/")[1] + "-W8A8" +save_path = os.path.join("./my_test/", save_name) +print("Starting quantization") +oneshot( + model=model, + tokenizer=model_id, + # dataset=ds, + dataset=DATASET_ID, + splits=DATASET_SPLIT, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, + output_dir=save_path, + data_collator=data_collator, +) + +# processor.save_pretrained(save_path) + +# Confirm generations of the quantized model look sane. +print("========== SAMPLE GENERATION ==============") +input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=20) +print(processor.decode(output[0])) +print("==========================================") diff --git a/qwen.py b/qwen.py deleted file mode 100644 index e087e9266..000000000 --- a/qwen.py +++ /dev/null @@ -1,108 +0,0 @@ -import torch -from datasets import load_dataset -from transformers import AutoProcessor, MllamaForConditionalGeneration, LlavaForConditionalGeneration, AutoModelForCausalLM, Qwen2VLForConditionalGeneration - -from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.transformers import oneshot -import os - -# Load model. -model_id = "Qwen/Qwen2-VL-2B-Instruct" -model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") -processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) - -print("Loading dataset") -DATASET_ID = "lmms-lab/flickr30k" -DATASET_SPLIT = "test[:512]" - -NUM_CALIBRATION_SAMPLES = 1 -MAX_SEQUENCE_LENGTH = 2048 - -# Load dataset and preprocess. -ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) -ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) - -print("Preprocessing samples") -def preprocess(example): - """ - Preprocesses a single example from the dataset. - """ - # Example messages structure - messages = [ - { - "role": "user", - "content": [ - {"type": "image"}, - {"type": "text", "text": "What does the image show?"} - ] - } - ] - return { - "text": [processor.apply_chat_template( - messages, - add_generation_prompt=True, - )], - } - -ds = ds.map(preprocess, remove_columns=["caption", "sentids", "img_id", "filename"]) - - -# Tokenize inputs. -def tokenize(sample): - image = sample.pop("image") - return processor( - **sample, - images=[image], - add_special_tokens=False, - return_tensors="pt" - ) - - -ds = ds.map(tokenize, remove_columns=ds.column_names) - -def data_collator(batch): - assert len(batch) == 1 - return { - "input_ids": torch.LongTensor(batch[0]["input_ids"]), - "attention_mask": torch.tensor(batch[0]["attention_mask"]), - "pixel_values": torch.tensor(batch[0]["pixel_values"]), # torch.Size([14308, 1176]) - "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]), - } - -print("Setting up quantization params") -# Configure the quantization algorithm and scheme. -# In this case, we: -# * quantize the weights to fp8 with per channel via ptq -# * quantize the activations to fp8 with dynamic per token -#ignore=["re:.*lm_head", "re:model.vision_embed_tokens.*"] -#ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*", "re:language_model.*cross_attn.*"], -ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"] - -recipe = [ - # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), - GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore, update_size=NUM_CALIBRATION_SAMPLES), -] - -save_name = model_id.split("/")[1] + "-W8A8" -save_path = os.path.join("./my_test/", save_name) -print("Starting quantization") -oneshot( - model=model, - tokenizer=model_id, - dataset=ds, - recipe=recipe, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, - trust_remote_code_model=True, - output_dir=save_path, - data_collator=data_collator, -) - -#processor.save_pretrained(save_path) - -# Confirm generations of the quantized model look sane. -print("========== SAMPLE GENERATION ==============") -input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") -output = model.generate(input_ids, max_new_tokens=20) -print(processor.decode(output[0])) -print("==========================================") diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index 11b7bbfa9..565fe0e70 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Union from compressed_tensors.registry import RegistryMixin -from datasets import Dataset, DatasetDict, IterableDataset +from datasets import Dataset, IterableDataset from loguru import logger from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments @@ -14,9 +14,8 @@ from llmcompressor.transformers.utils.preprocessing_functions import ( PreprocessingFunctionRegistry, ) -from llmcompressor.utils import Processor, import_from_path - -DatasetType = Union[Dataset, DatasetDict, IterableDataset] +from llmcompressor.utils import import_from_path +from llmcompressor.utils.typing import DatasetType, Processor class TextGenerationDataset(RegistryMixin): @@ -106,7 +105,7 @@ def __call__(self, add_labels: bool = True) -> DatasetType: dataset = self.map( dataset, self.tokenize, - batched=True, + batched=False, remove_columns=dataset.column_names, num_proc=self.data_args.preprocessing_num_workers, load_from_cache_file=not self.data_args.overwrite_cache, @@ -172,7 +171,6 @@ def load_dataset(self): @cached_property def preprocess(self) -> Union[Callable[[Any], Any], None]: """ - The function must return keys which correspond to tokenizer kwargs, optionally including PROMPT_KEY """ @@ -217,6 +215,7 @@ def tokenize(self, data: Dict[str, Any]) -> Dict[str, Any]: padding=self.padding, max_length=self.max_seq_length, truncation=True, + return_tensors="pt", ) # store unpadded prompt so we can mask out correct number of elements in labels diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index 1b0b589f9..21119d756 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -1,5 +1,6 @@ import logging import os +import warnings from typing import Any, Callable, Dict, List, Optional import torch @@ -7,6 +8,8 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from transformers.data import default_data_collator +from llmcompressor.utils.typing import DatasetType + LOGGER = logging.getLogger(__name__) LABELS_MASK_VALUE = -100 @@ -92,7 +95,7 @@ def get_raw_dataset( def make_dataset_splits( - tokenized_datasets: Dict[str, Any], + datasets: Dict[str, Any], do_train: bool = False, do_eval: bool = False, do_predict: bool = False, @@ -102,7 +105,7 @@ def make_dataset_splits( Restructures the datasets dictionary based on what tasks will be run (train, eval, predict) - :param tokenized_datasets: dictionary of processed datasets + :param datasets: dictionary of processed datasets :param do_train: Whether to store the train dataset :param do_eval: Whether to store the validation dataset :param do_predict: Whether to store the test dataset @@ -111,31 +114,32 @@ def make_dataset_splits( """ # handles case where all splits are contained in a single dataset - if "all" in tokenized_datasets and len(tokenized_datasets) == 1: - tokenized_datasets = tokenized_datasets.get("all") - if isinstance(tokenized_datasets, Dataset): - tokenized_datasets = {"train": tokenized_datasets} + if "all" in datasets and len(datasets) == 1: + datasets = datasets.get("all") + if isinstance(datasets, Dataset): + datasets = {"train": datasets} train_split = eval_split = predict_split = calib_split = None if do_train: - if "train" not in tokenized_datasets: - raise ValueError("--do_train requires a train dataset") - train_split = tokenized_datasets["train"] + train_split = _get_split_with_fallbacks( + datasets, "train", ["train"], strict=True + ) if do_eval: - if "validation" not in tokenized_datasets: - raise ValueError("--do_eval requires a validation dataset") - eval_split = tokenized_datasets["validation"] + eval_split = _get_split_with_fallbacks( + datasets, "evaluation", ["validation", "test"], strict=True + ) if do_predict: - if "test" not in tokenized_datasets: - raise ValueError("--do_predict requires a test dataset") - predict_split = tokenized_datasets["test"] + predict_split = _get_split_with_fallbacks( + datasets, "prediction", ["test", "validation"], strict=True + ) if do_oneshot: - calib_split = tokenized_datasets.get("calibration") - if calib_split is None: - if "train" not in tokenized_datasets: - raise ValueError("--do_oneshot requires a calibration dataset") - calib_split = tokenized_datasets["train"] + calib_split = _get_split_with_fallbacks( + datasets, + "oneshot", + ["calibration", "train", "test", "validation"], + strict=False, + ) # remove labels from calibration dataset column_names = calib_split.column_names @@ -250,3 +254,34 @@ def do_transform(candidate: str) -> bool: transform_dataset_key(dataset_key) return data_files + + +def _get_split_with_fallbacks( + datasets: Dict[str, DatasetType], + task: str, + fallbacks: List[str], + strict: bool = True, +) -> DatasetType: + assert len(fallbacks) > 0 + if len(datasets) <= 0: + raise ValueError("Cannot get retrieve data from dataset with no splits") + + # check first choice + first_choice = fallbacks[0] + if first_choice in datasets: + return datasets[first_choice] + + # last fallback is first available split + if not strict: + fallbacks.append(next(iter(datasets.keys()))) + + # check fallbacks + for fallback in fallbacks[1:]: + if fallback in datasets: + warnings.warn( + f"{task} expects a {first_choice} dataset split, " + f"falling back to {fallback}" + ) + return datasets[fallback] + + raise ValueError(f"{task} expects at least one of {fallbacks} dataset splits") diff --git a/src/llmcompressor/transformers/finetune/data/flickr_30k.py b/src/llmcompressor/transformers/finetune/data/flickr_30k.py index 8802ccd5a..1514e5682 100644 --- a/src/llmcompressor/transformers/finetune/data/flickr_30k.py +++ b/src/llmcompressor/transformers/finetune/data/flickr_30k.py @@ -47,15 +47,21 @@ def dataset_template(self, sample): if self.processor is None: raise ValueError("TODO") - breakpoint() - - messages = sample["messages"] - if messages[0]["role"] != "system": - messages.insert(0, {"role": "system", "content": ""}) - + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What does the image show?"}, + ], + } + ] return { - "text": self.processor.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ), - "image": sample["image"], + "text": [ + self.processor.apply_chat_template( + messages, + add_generation_prompt=True, + ) + ], + "images": sample["image"], } diff --git a/src/llmcompressor/utils/typing.py b/src/llmcompressor/utils/typing.py index 2f4bf4fc8..1050f7138 100644 --- a/src/llmcompressor/utils/typing.py +++ b/src/llmcompressor/utils/typing.py @@ -1,5 +1,6 @@ from typing import Union +from datasets import Dataset, DatasetDict, IterableDataset from transformers import ( BaseImageProcessor, FeatureExtractionMixin, @@ -11,3 +12,6 @@ Processor = Union[ PreTrainedTokenizer, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin ] + +# Supported dataset types, IterableDataset is a streamed dataset +DatasetType = Union[Dataset, DatasetDict, IterableDataset] diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py index 6bd0312b1..b676cfb70 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py @@ -1,13 +1,20 @@ import unittest import pytest +import torch from datasets import IterableDataset, load_dataset from parameterized import parameterized -from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments +from llmcompressor.transformers import ( + DataTrainingArguments, + ModelArguments, + TextGenerationDataset, + TrainingArguments, +) +from llmcompressor.transformers.finetune.data.data_helpers import ( + format_calibration_data, +) from llmcompressor.transformers.finetune.runner import StageRunner -from llmcompressor.transformers.finetune.training_args import TrainingArguments from tests.testing_utils import requires_torch @@ -270,8 +277,6 @@ def prepare_fixture(self, tiny_llama_tokenizer): [["train"], ["train[60%:]"], [{"train": "train[:20%]"}], [None]] ) def test_split_loading(self, split_def): - from llmcompressor.transformers.finetune.model_args import ModelArguments - data_args = DataTrainingArguments( dataset="open_platypus", splits=split_def, @@ -301,12 +306,6 @@ def prepare_fixture(self, tiny_llama_tokenizer): self.dataset = dataset.shuffle(seed=42).select(range(self.num_calib_samples)) def test_load_tokenized_data(self): - import torch - - from llmcompressor.transformers.finetune.data.data_helpers import ( - format_calibration_data, - ) - def preprocess(sample): concat_text = "INPUT: " + sample.get("input", "") concat_text += "INSTRUCTIONS: " + sample.get("instruction", "") From 5276c5896bf0b218c21633557fe037880a106001 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 4 Dec 2024 22:19:15 +0000 Subject: [PATCH 133/285] discover bug, tests and multimodal working --- examples/multimodal_vision/mllama.py | 62 +++++++++++++++++++ examples/multimodal_vision/qwen_vl2.py | 20 +++--- .../transformers/finetune/data/base.py | 15 +++-- .../transformers/finetune/data/flickr_30k.py | 10 ++- 4 files changed, 83 insertions(+), 24 deletions(-) create mode 100644 examples/multimodal_vision/mllama.py diff --git a/examples/multimodal_vision/mllama.py b/examples/multimodal_vision/mllama.py new file mode 100644 index 000000000..fe19119ce --- /dev/null +++ b/examples/multimodal_vision/mllama.py @@ -0,0 +1,62 @@ +import torch +from transformers import AutoProcessor, MllamaForConditionalGeneration + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import oneshot +import os + +# Load model. +model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" +model = MllamaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") +processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + +# Oneshot arguments +DATASET_ID = "flickr30k" +DATASET_SPLIT = "test[:512]" +NUM_CALIBRATION_SAMPLES = 1 +MAX_SEQUENCE_LENGTH = 2048 + +# TODO: define real collators in utils +def data_collator(batch): + assert len(batch) == 1 + return { + "input_ids": torch.LongTensor(batch[0]["input_ids"]), + "attention_mask": torch.tensor(batch[0]["attention_mask"]), + "pixel_values": torch.tensor(batch[0]["pixel_values"]), + "aspect_ratio_ids": torch.tensor(batch[0]["aspect_ratio_ids"]), + "aspect_ratio_mask": torch.tensor(batch[0]["aspect_ratio_mask"]), + "cross_attention_mask": torch.tensor(batch[0]["cross_attention_mask"]), + } + + +# Recipe +recipe = [ + # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), + GPTQModifier(targets="Linear", scheme="W8A8", ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"], update_size=NUM_CALIBRATION_SAMPLES), +] + +# Perform oneshot +save_name = model_id.split("/")[1] + "-W8A8" +save_path = os.path.join("./my_test/", save_name) +print("Starting quantization") +oneshot( + model=model, + tokenizer=model_id, + dataset=DATASET_ID, + splits=DATASET_SPLIT, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, + output_dir=save_path, + data_collator=data_collator, +) + +processor.save_pretrained(save_path) + +# Confirm generations of the quantized model look sane. +print("========== SAMPLE GENERATION ==============") +input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=20) +print(processor.decode(output[0])) +print("==========================================") diff --git a/examples/multimodal_vision/qwen_vl2.py b/examples/multimodal_vision/qwen_vl2.py index c8cfb6e73..c55051bc0 100644 --- a/examples/multimodal_vision/qwen_vl2.py +++ b/examples/multimodal_vision/qwen_vl2.py @@ -8,18 +8,16 @@ # Load model. model_id = "Qwen/Qwen2-VL-2B-Instruct" -model = Qwen2VLForConditionalGeneration.from_pretrained( - model_id, device_map="auto", torch_dtype="auto" -) +model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) -# oneshot arguments +# Oneshot arguments DATASET_ID = "flickr30k" -DATASET_SPLIT = "test[:512]" +DATASET_SPLIT = "test[:3]" NUM_CALIBRATION_SAMPLES = 1 MAX_SEQUENCE_LENGTH = 2048 - +# TODO: define real collators in utils def data_collator(batch): assert len(batch) == 1 return { @@ -31,16 +29,13 @@ def data_collator(batch): "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]), } - -ignore = ["re:.*lm_head"] - +# Recipe from compressed_tensors.quantization import ( QuantizationArgs, QuantizationScheme, QuantizationStrategy, QuantizationType, ) - recipe = GPTQModifier( targets="Linear", config_groups={ @@ -57,11 +52,12 @@ def data_collator(batch): ), ), }, - ignore=ignore, + ignore=["re:.*lm_head"], update_size=NUM_CALIBRATION_SAMPLES, dampening_frac=0.5, ) +# Perform oneshot save_name = model_id.split("/")[1] + "-W8A8" save_path = os.path.join("./my_test/", save_name) print("Starting quantization") @@ -79,7 +75,7 @@ def data_collator(batch): data_collator=data_collator, ) -# processor.save_pretrained(save_path) +processor.save_pretrained(save_path) # Confirm generations of the quantized model look sane. print("========== SAMPLE GENERATION ==============") diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index 565fe0e70..b8b48d974 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -1,3 +1,4 @@ +import torch from functools import cached_property from typing import Any, Callable, Dict, List, Optional, Union @@ -45,6 +46,8 @@ def __init__( self.data_args = data_args self.split = split self.processor = processor + #from transformers import AutoProcessor + #self.processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct", trust_remote_code=True) # get tokenizer self.tokenizer = getattr(self.processor, "tokenizer", self.processor) @@ -105,7 +108,10 @@ def __call__(self, add_labels: bool = True) -> DatasetType: dataset = self.map( dataset, self.tokenize, - batched=False, + batched=False, # batching is not well supported for vision processors + keep_in_memory=True, # bug occurs when not batched and not in memory, + # subsequent ds.map calls are always batched, + # regardless of `batched` argument remove_columns=dataset.column_names, num_proc=self.data_args.preprocessing_num_workers, load_from_cache_file=not self.data_args.overwrite_cache, @@ -215,7 +221,6 @@ def tokenize(self, data: Dict[str, Any]) -> Dict[str, Any]: padding=self.padding, max_length=self.max_seq_length, truncation=True, - return_tensors="pt", ) # store unpadded prompt so we can mask out correct number of elements in labels @@ -259,7 +264,7 @@ def add_labels(self, data): def map( self, dataset: Union[Dataset, IterableDataset], - function: Union[Callable[[Any], Any], None], + function: Callable[[Any], Any], remove_columns: Optional[Union[str, List[str], Dict[str, List[str]]]] = None, **kwargs, ) -> Union[Dataset, IterableDataset]: @@ -269,14 +274,12 @@ def map( 1. Clears invalid parameters in the case where streaming is enabled 2. Skips removing columns which were already removed after mapping """ - if function is None: - return dataset - if isinstance(dataset, IterableDataset): # remove arguments that don't apply to streaming kwargs.pop("num_proc", None) kwargs.pop("load_from_cache_file", None) kwargs.pop("desc", None) + kwargs.pop("keep_in_memory", None) dataset = dataset.map(function, **kwargs) diff --git a/src/llmcompressor/transformers/finetune/data/flickr_30k.py b/src/llmcompressor/transformers/finetune/data/flickr_30k.py index 1514e5682..112c69647 100644 --- a/src/llmcompressor/transformers/finetune/data/flickr_30k.py +++ b/src/llmcompressor/transformers/finetune/data/flickr_30k.py @@ -57,11 +57,9 @@ def dataset_template(self, sample): } ] return { - "text": [ - self.processor.apply_chat_template( - messages, - add_generation_prompt=True, - ) - ], + "text": self.processor.apply_chat_template( + messages, + add_generation_prompt=True, + ), "images": sample["image"], } From dffcbc30a0fa9c67ecbd5e2b28755bb71a386794 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 4 Dec 2024 22:50:52 +0000 Subject: [PATCH 134/285] dataset split fallbacks Signed-off-by: Kyle Sayers --- .../finetune/data/data_helpers.py | 84 ++++++++++++++----- src/llmcompressor/typing.py | 17 ++++ .../finetune/data/test_dataset_helpers.py | 30 ++++++- 3 files changed, 107 insertions(+), 24 deletions(-) create mode 100644 src/llmcompressor/typing.py diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index 23c70e561..148cf85af 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -1,5 +1,6 @@ import logging import os +import warnings from typing import Any, Callable, Dict, List, Optional import torch @@ -7,6 +8,8 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from transformers.data import default_data_collator +from llmcompressor.typing import DatasetType + LOGGER = logging.getLogger(__name__) LABELS_MASK_VALUE = -100 @@ -92,17 +95,17 @@ def get_raw_dataset( def make_dataset_splits( - tokenized_datasets: Dict[str, Any], + datasets: Dict[str, DatasetType], do_train: bool = False, do_eval: bool = False, do_predict: bool = False, do_oneshot: bool = False, -) -> Dict[str, Dataset]: +) -> Dict[str, DatasetType]: """ Restructures the datasets dictionary based on what tasks will be run (train, eval, predict) - :param tokenized_datasets: dictionary of processed datasets + :param datasets: dictionary of processed datasets :param do_train: Whether to store the train dataset :param do_eval: Whether to store the validation dataset :param do_predict: Whether to store the test dataset @@ -111,31 +114,39 @@ def make_dataset_splits( """ # handles case where all splits are contained in a single dataset - if "all" in tokenized_datasets and len(tokenized_datasets) == 1: - tokenized_datasets = tokenized_datasets.get("all") - if isinstance(tokenized_datasets, Dataset): - tokenized_datasets = {"train": tokenized_datasets} + if "all" in datasets and len(datasets) == 1: + datasets = datasets.get("all") + if isinstance(datasets, Dataset): + datasets = {"train": datasets} train_split = eval_split = predict_split = calib_split = None if do_train: - if "train" not in tokenized_datasets: - raise ValueError("--do_train requires a train dataset") - train_split = tokenized_datasets["train"] + train_split = _get_split_with_fallbacks( + datasets, "train", ["train"], strict=True + ) if do_eval: - if "validation" not in tokenized_datasets: - raise ValueError("--do_eval requires a validation dataset") - eval_split = tokenized_datasets["validation"] + eval_split = _get_split_with_fallbacks( + datasets, "evaluation", ["validation", "test"], strict=True + ) if do_predict: - if "test" not in tokenized_datasets: - raise ValueError("--do_predict requires a test dataset") - predict_split = tokenized_datasets["test"] + predict_split = _get_split_with_fallbacks( + datasets, "prediction", ["test", "validation"], strict=True + ) if do_oneshot: - calib_split = tokenized_datasets.get("calibration") - if calib_split is None: - if "train" not in tokenized_datasets: - raise ValueError("--do_oneshot requires a calibration dataset") - calib_split = tokenized_datasets["train"] + calib_split = _get_split_with_fallbacks( + datasets, + "oneshot", + ["calibration", "train", "test", "validation"], + strict=False, + ) + + # remove labels from calibration dataset + column_names = calib_split.column_names + if isinstance(column_names, dict): + column_names = sum(column_names.values(), []) + if "labels" in column_names: + calib_split = calib_split.remove_columns("labels") split_datasets = { "train": train_split, @@ -243,3 +254,34 @@ def do_transform(candidate: str) -> bool: transform_dataset_key(dataset_key) return data_files + + +def _get_split_with_fallbacks( + datasets: Dict[str, DatasetType], + task: str, + fallbacks: List[str], + strict: bool = True, +) -> DatasetType: + assert len(fallbacks) > 0 + if len(datasets) <= 0: + raise ValueError("Cannot get retrieve data from dataset with no splits") + + # check first choice + first_choice = fallbacks[0] + if first_choice in datasets: + return datasets[first_choice] + + # last fallback is first available split + if not strict: + fallbacks.append(next(iter(datasets.keys()))) + + # check fallbacks + for fallback in fallbacks[1:]: + if fallback in datasets: + warnings.warn( + f"{task} expects a {first_choice} dataset split, " + f"falling back to {fallback}" + ) + return datasets[fallback] + + raise ValueError(f"{task} expects at least one of {fallbacks} dataset splits") diff --git a/src/llmcompressor/typing.py b/src/llmcompressor/typing.py new file mode 100644 index 000000000..1050f7138 --- /dev/null +++ b/src/llmcompressor/typing.py @@ -0,0 +1,17 @@ +from typing import Union + +from datasets import Dataset, DatasetDict, IterableDataset +from transformers import ( + BaseImageProcessor, + FeatureExtractionMixin, + PreTrainedTokenizer, + ProcessorMixin, +) + +# Tokenizer or Processor. Processors do not inherit from a unified base class +Processor = Union[ + PreTrainedTokenizer, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin +] + +# Supported dataset types, IterableDataset is a streamed dataset +DatasetType = Union[Dataset, DatasetDict, IterableDataset] diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py index 812b26a56..5229ea735 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py @@ -1,3 +1,5 @@ +from unittest.mock import Mock + import pytest from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments @@ -48,8 +50,30 @@ def test_separate_datasets(): assert split_datasets.get("validation") is not None assert split_datasets.get("test") is None + +@pytest.mark.unit +def test_datasets_fallbacks(): + # strict splits + mock_datasets = {"calibration": Mock(ds_name="calibration_ds", column_names=[])} + with pytest.raises(ValueError): + _ = make_dataset_splits(mock_datasets, do_train=True) with pytest.raises(ValueError): - # fails due to no test split specified - split_datasets = make_dataset_splits( - datasets, do_train=True, do_eval=True, do_predict=True + _ = make_dataset_splits(mock_datasets, do_eval=True) + with pytest.raises(ValueError): + _ = make_dataset_splits(mock_datasets, do_predict=True) + + # validation, predict, and oneshot fallbacks + mock_datasets = {"test": Mock(ds_name="test_ds", column_names=[])} + with pytest.warns(UserWarning): + split_ds = make_dataset_splits( + mock_datasets, do_eval=True, do_predict=True, do_oneshot=True ) + assert split_ds.get("validation").ds_name == "test_ds" + assert split_ds.get("test").ds_name == "test_ds" + assert split_ds.get("calibration").ds_name == "test_ds" + + # oneshot will take any dataset + mock_datasets = {"custom_split": Mock(ds_name="custom_ds", column_names=[])} + with pytest.warns(UserWarning): + split_ds = make_dataset_splits(mock_datasets, do_oneshot=True) + assert split_ds.get("calibration").ds_name == "custom_ds" From e9f150d8043896b8ca0d3ab92ecb24c6035b7272 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 4 Dec 2024 23:25:20 +0000 Subject: [PATCH 135/285] move typing --- examples/multimodal_vision/mllama.py | 15 ++++++++++++--- examples/multimodal_vision/qwen_vl2.py | 7 ++++++- .../modifiers/quantization/gptq/base.py | 3 +-- .../gptq/utils/partitioned_model.py | 12 ++++++++++-- src/llmcompressor/pipelines/peicewise.py | 15 ++++++++++++--- src/llmcompressor/pytorch/model_load/helpers.py | 2 +- .../transformers/finetune/data/base.py | 12 ++++++------ .../transformers/finetune/data/c4.py | 2 +- .../transformers/finetune/data/cnn_dailymail.py | 2 +- .../finetune/data/evolcodealpaca.py | 2 +- .../transformers/finetune/data/flickr_30k.py | 2 +- .../transformers/finetune/data/gsm8k.py | 2 +- .../transformers/finetune/data/open_platypus.py | 2 +- .../transformers/finetune/data/ptb.py | 2 +- .../finetune/data/ultrachat_200k.py | 2 +- .../transformers/finetune/data/wikitext.py | 2 +- .../transformers/finetune/runner.py | 2 +- .../transformers/finetune/text_generation.py | 2 +- src/llmcompressor/utils/__init__.py | 1 - src/llmcompressor/utils/typing.py | 17 ----------------- 20 files changed, 59 insertions(+), 47 deletions(-) delete mode 100644 src/llmcompressor/utils/typing.py diff --git a/examples/multimodal_vision/mllama.py b/examples/multimodal_vision/mllama.py index fe19119ce..46f9acabe 100644 --- a/examples/multimodal_vision/mllama.py +++ b/examples/multimodal_vision/mllama.py @@ -1,13 +1,16 @@ +import os + import torch from transformers import AutoProcessor, MllamaForConditionalGeneration from llmcompressor.modifiers.quantization import GPTQModifier from llmcompressor.transformers import oneshot -import os # Load model. model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" -model = MllamaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") +model = MllamaForConditionalGeneration.from_pretrained( + model_id, device_map="auto", torch_dtype="auto" +) processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) # Oneshot arguments @@ -16,6 +19,7 @@ NUM_CALIBRATION_SAMPLES = 1 MAX_SEQUENCE_LENGTH = 2048 + # TODO: define real collators in utils def data_collator(batch): assert len(batch) == 1 @@ -32,7 +36,12 @@ def data_collator(batch): # Recipe recipe = [ # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), - GPTQModifier(targets="Linear", scheme="W8A8", ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"], update_size=NUM_CALIBRATION_SAMPLES), + GPTQModifier( + targets="Linear", + scheme="W8A8", + ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"], + update_size=NUM_CALIBRATION_SAMPLES, + ), ] # Perform oneshot diff --git a/examples/multimodal_vision/qwen_vl2.py b/examples/multimodal_vision/qwen_vl2.py index c55051bc0..d6e10b10b 100644 --- a/examples/multimodal_vision/qwen_vl2.py +++ b/examples/multimodal_vision/qwen_vl2.py @@ -8,7 +8,9 @@ # Load model. model_id = "Qwen/Qwen2-VL-2B-Instruct" -model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") +model = Qwen2VLForConditionalGeneration.from_pretrained( + model_id, device_map="auto", torch_dtype="auto" +) processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) # Oneshot arguments @@ -17,6 +19,7 @@ NUM_CALIBRATION_SAMPLES = 1 MAX_SEQUENCE_LENGTH = 2048 + # TODO: define real collators in utils def data_collator(batch): assert len(batch) == 1 @@ -29,6 +32,7 @@ def data_collator(batch): "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]), } + # Recipe from compressed_tensors.quantization import ( QuantizationArgs, @@ -36,6 +40,7 @@ def data_collator(batch): QuantizationStrategy, QuantizationType, ) + recipe = GPTQModifier( targets="Linear", config_groups={ diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index b25fafb6b..defbb5ee3 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -17,7 +17,7 @@ quantize_weight, ) from llmcompressor.modifiers.quantization.gptq.utils.partitioned_model import ( - PartitionedModel + PartitionedModel, ) from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier from llmcompressor.modifiers.utils.hooks import HooksMixin @@ -305,7 +305,6 @@ def _maybe_offload_hessians(self, module: torch.nn.Module): if self.offload_hessians: self._hessians[module] = self._hessians[module].to(device=device) - def _build_quant_modifier(self): """ diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py index c81c34c97..e1f796a2b 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py @@ -208,8 +208,16 @@ def is_leaf_module( concrete_args[parameter.name] = value - #self.graph: GraphModule = symbolic_trace(model, disable_check=True, tracer_cls=CustomTracer) - self.graph: GraphModule = torch.fx.GraphModule(model, CustomTracer().trace(model, dummy_inputs=model.dummy_inputs, concrete_args=concrete_args, complete_concrete_args_with_inputs_not_in_dummy_inputs=False)) + # self.graph: GraphModule = symbolic_trace(model, disable_check=True, tracer_cls=CustomTracer) + self.graph: GraphModule = torch.fx.GraphModule( + model, + CustomTracer().trace( + model, + dummy_inputs=model.dummy_inputs, + concrete_args=concrete_args, + complete_concrete_args_with_inputs_not_in_dummy_inputs=False, + ), + ) self.graph.config = model.config self.graph.class_for_deserialization = model.__class__ self.graph.device = model.device diff --git a/src/llmcompressor/pipelines/peicewise.py b/src/llmcompressor/pipelines/peicewise.py index cbd839cc0..c4e94e26b 100644 --- a/src/llmcompressor/pipelines/peicewise.py +++ b/src/llmcompressor/pipelines/peicewise.py @@ -1,13 +1,20 @@ import contextlib -import torch +import torch from datasets import Dataset from llmcompressor.core.session_functions import initialize from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException from llmcompressor.recipe.recipe import Recipe -from llmcompressor.utils.helpers import calibration_forward_context, trace_subgraphs, get_targets, get_model_device, tensors_to_device, create_dataloader +from llmcompressor.utils.helpers import ( + calibration_forward_context, + create_dataloader, + get_model_device, + get_targets, + tensors_to_device, + trace_subgraphs, +) def run_pipeline( @@ -26,7 +33,9 @@ def run_pipeline( # create dataloader model_device = get_model_device(model) - dataloader = create_dataloader(dataset, batch_size=..., mask_padding=True, model_device=model_device) + dataloader = create_dataloader( + dataset, batch_size=..., mask_padding=True, model_device=model_device + ) with calibration_forward_context(model): # prepare intermediates cache diff --git a/src/llmcompressor/pytorch/model_load/helpers.py b/src/llmcompressor/pytorch/model_load/helpers.py index 15c2adfeb..a9ecb67a7 100644 --- a/src/llmcompressor/pytorch/model_load/helpers.py +++ b/src/llmcompressor/pytorch/model_load/helpers.py @@ -9,7 +9,7 @@ from llmcompressor.core import active_session, create_session, pre_initialize_structure from llmcompressor.pytorch.utils import ModuleSparsificationInfo -from llmcompressor.utils import Processor +from llmcompressor.typing import Processor COMPLETED_STAGES_FILENAME = "completed_stages.json" diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index b8b48d974..5d93d2366 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -1,7 +1,7 @@ -import torch from functools import cached_property from typing import Any, Callable, Dict, List, Optional, Union +import torch from compressed_tensors.registry import RegistryMixin from datasets import Dataset, IterableDataset from loguru import logger @@ -15,8 +15,8 @@ from llmcompressor.transformers.utils.preprocessing_functions import ( PreprocessingFunctionRegistry, ) +from llmcompressor.typing import DatasetType, Processor from llmcompressor.utils import import_from_path -from llmcompressor.utils.typing import DatasetType, Processor class TextGenerationDataset(RegistryMixin): @@ -46,8 +46,8 @@ def __init__( self.data_args = data_args self.split = split self.processor = processor - #from transformers import AutoProcessor - #self.processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct", trust_remote_code=True) + # from transformers import AutoProcessor + # self.processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct", trust_remote_code=True) # get tokenizer self.tokenizer = getattr(self.processor, "tokenizer", self.processor) @@ -110,8 +110,8 @@ def __call__(self, add_labels: bool = True) -> DatasetType: self.tokenize, batched=False, # batching is not well supported for vision processors keep_in_memory=True, # bug occurs when not batched and not in memory, - # subsequent ds.map calls are always batched, - # regardless of `batched` argument + # subsequent ds.map calls are always batched, + # regardless of `batched` argument remove_columns=dataset.column_names, num_proc=self.data_args.preprocessing_num_workers, load_from_cache_file=not self.data_args.overwrite_cache, diff --git a/src/llmcompressor/transformers/finetune/data/c4.py b/src/llmcompressor/transformers/finetune/data/c4.py index 7f02c0eaf..e50d4d0c6 100644 --- a/src/llmcompressor/transformers/finetune/data/c4.py +++ b/src/llmcompressor/transformers/finetune/data/c4.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.utils import Processor +from llmcompressor.typing import Processor if TYPE_CHECKING: from llmcompressor.transformers import DataTrainingArguments as DataArgs diff --git a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py index b4e6484a0..06ad3ecfa 100644 --- a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py +++ b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.utils import Processor +from llmcompressor.typing import Processor if TYPE_CHECKING: from llmcompressor.transformers import DataTrainingArguments as DataArgs diff --git a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py index 784cd55fe..932bfa54c 100644 --- a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py +++ b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.utils import Processor +from llmcompressor.typing import Processor if TYPE_CHECKING: from llmcompressor.transformers import DataTrainingArguments as DataArgs diff --git a/src/llmcompressor/transformers/finetune/data/flickr_30k.py b/src/llmcompressor/transformers/finetune/data/flickr_30k.py index 112c69647..03fd0fa7b 100644 --- a/src/llmcompressor/transformers/finetune/data/flickr_30k.py +++ b/src/llmcompressor/transformers/finetune/data/flickr_30k.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.utils import Processor +from llmcompressor.typing import Processor if TYPE_CHECKING: from llmcompressor.transformers import DataTrainingArguments as DataArgs diff --git a/src/llmcompressor/transformers/finetune/data/gsm8k.py b/src/llmcompressor/transformers/finetune/data/gsm8k.py index 36324fa87..beae5dfec 100644 --- a/src/llmcompressor/transformers/finetune/data/gsm8k.py +++ b/src/llmcompressor/transformers/finetune/data/gsm8k.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.utils import Processor +from llmcompressor.typing import Processor if TYPE_CHECKING: from llmcompressor.transformers import DataTrainingArguments as DataArgs diff --git a/src/llmcompressor/transformers/finetune/data/open_platypus.py b/src/llmcompressor/transformers/finetune/data/open_platypus.py index b4a136ba0..3b25986ca 100644 --- a/src/llmcompressor/transformers/finetune/data/open_platypus.py +++ b/src/llmcompressor/transformers/finetune/data/open_platypus.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.utils import Processor +from llmcompressor.typing import Processor if TYPE_CHECKING: from llmcompressor.transformers import DataTrainingArguments as DataArgs diff --git a/src/llmcompressor/transformers/finetune/data/ptb.py b/src/llmcompressor/transformers/finetune/data/ptb.py index 6b8bd3b61..c7f0bbac1 100644 --- a/src/llmcompressor/transformers/finetune/data/ptb.py +++ b/src/llmcompressor/transformers/finetune/data/ptb.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.utils import Processor +from llmcompressor.typing import Processor if TYPE_CHECKING: from llmcompressor.transformers import DataTrainingArguments as DataArgs diff --git a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py index 3e5a30b7b..b11272de2 100644 --- a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py +++ b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.utils import Processor +from llmcompressor.typing import Processor if TYPE_CHECKING: from llmcompressor.transformers import DataTrainingArguments as DataArgs diff --git a/src/llmcompressor/transformers/finetune/data/wikitext.py b/src/llmcompressor/transformers/finetune/data/wikitext.py index 7ea6953f6..a559399d8 100644 --- a/src/llmcompressor/transformers/finetune/data/wikitext.py +++ b/src/llmcompressor/transformers/finetune/data/wikitext.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.utils import Processor +from llmcompressor.typing import Processor if TYPE_CHECKING: from llmcompressor.transformers import DataTrainingArguments as DataArgs diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py index 368da499a..138ad94ff 100644 --- a/src/llmcompressor/transformers/finetune/runner.py +++ b/src/llmcompressor/transformers/finetune/runner.py @@ -23,7 +23,7 @@ ) from llmcompressor.transformers.finetune.model_args import ModelArguments from llmcompressor.transformers.finetune.training_args import TrainingArguments -from llmcompressor.utils import Processor +from llmcompressor.typing import Processor from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_model_and_recipe diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index ffea2df9d..2f1b44b17 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -52,7 +52,7 @@ get_shared_processor_src, ) from llmcompressor.transformers.utils.helpers import detect_last_checkpoint -from llmcompressor.utils import Processor +from llmcompressor.typing import Processor from llmcompressor.utils.fsdp.helpers import is_fsdp_model diff --git a/src/llmcompressor/utils/__init__.py b/src/llmcompressor/utils/__init__.py index d36055acb..98d5e1c65 100644 --- a/src/llmcompressor/utils/__init__.py +++ b/src/llmcompressor/utils/__init__.py @@ -5,4 +5,3 @@ # flake8: noqa from .helpers import * -from .typing import * diff --git a/src/llmcompressor/utils/typing.py b/src/llmcompressor/utils/typing.py deleted file mode 100644 index 1050f7138..000000000 --- a/src/llmcompressor/utils/typing.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Union - -from datasets import Dataset, DatasetDict, IterableDataset -from transformers import ( - BaseImageProcessor, - FeatureExtractionMixin, - PreTrainedTokenizer, - ProcessorMixin, -) - -# Tokenizer or Processor. Processors do not inherit from a unified base class -Processor = Union[ - PreTrainedTokenizer, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin -] - -# Supported dataset types, IterableDataset is a streamed dataset -DatasetType = Union[Dataset, DatasetDict, IterableDataset] From d061567ca575f22257616404addd7581fadfe02d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 4 Dec 2024 23:56:09 +0000 Subject: [PATCH 136/285] cleanup, depreciate remove_columns argument --- examples/multimodal_vision/mllama.py | 15 ++++++++++++--- examples/multimodal_vision/qwen_vl2.py | 7 ++++++- .../transformers/finetune/data/base.py | 8 +++----- .../transformers/finetune/data/data_args.py | 4 +++- .../transformers/finetune/text_generation.py | 11 +++++++++++ .../transformers/oneshot/dataset_processing.py | 3 --- .../transformers/oneshot/test_api_inputs.py | 5 +---- 7 files changed, 36 insertions(+), 17 deletions(-) diff --git a/examples/multimodal_vision/mllama.py b/examples/multimodal_vision/mllama.py index fe19119ce..46f9acabe 100644 --- a/examples/multimodal_vision/mllama.py +++ b/examples/multimodal_vision/mllama.py @@ -1,13 +1,16 @@ +import os + import torch from transformers import AutoProcessor, MllamaForConditionalGeneration from llmcompressor.modifiers.quantization import GPTQModifier from llmcompressor.transformers import oneshot -import os # Load model. model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" -model = MllamaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") +model = MllamaForConditionalGeneration.from_pretrained( + model_id, device_map="auto", torch_dtype="auto" +) processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) # Oneshot arguments @@ -16,6 +19,7 @@ NUM_CALIBRATION_SAMPLES = 1 MAX_SEQUENCE_LENGTH = 2048 + # TODO: define real collators in utils def data_collator(batch): assert len(batch) == 1 @@ -32,7 +36,12 @@ def data_collator(batch): # Recipe recipe = [ # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), - GPTQModifier(targets="Linear", scheme="W8A8", ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"], update_size=NUM_CALIBRATION_SAMPLES), + GPTQModifier( + targets="Linear", + scheme="W8A8", + ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"], + update_size=NUM_CALIBRATION_SAMPLES, + ), ] # Perform oneshot diff --git a/examples/multimodal_vision/qwen_vl2.py b/examples/multimodal_vision/qwen_vl2.py index c55051bc0..d6e10b10b 100644 --- a/examples/multimodal_vision/qwen_vl2.py +++ b/examples/multimodal_vision/qwen_vl2.py @@ -8,7 +8,9 @@ # Load model. model_id = "Qwen/Qwen2-VL-2B-Instruct" -model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") +model = Qwen2VLForConditionalGeneration.from_pretrained( + model_id, device_map="auto", torch_dtype="auto" +) processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) # Oneshot arguments @@ -17,6 +19,7 @@ NUM_CALIBRATION_SAMPLES = 1 MAX_SEQUENCE_LENGTH = 2048 + # TODO: define real collators in utils def data_collator(batch): assert len(batch) == 1 @@ -29,6 +32,7 @@ def data_collator(batch): "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]), } + # Recipe from compressed_tensors.quantization import ( QuantizationArgs, @@ -36,6 +40,7 @@ def data_collator(batch): QuantizationStrategy, QuantizationType, ) + recipe = GPTQModifier( targets="Linear", config_groups={ diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index b8b48d974..971a91915 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -1,4 +1,3 @@ -import torch from functools import cached_property from typing import Any, Callable, Dict, List, Optional, Union @@ -46,8 +45,6 @@ def __init__( self.data_args = data_args self.split = split self.processor = processor - #from transformers import AutoProcessor - #self.processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct", trust_remote_code=True) # get tokenizer self.tokenizer = getattr(self.processor, "tokenizer", self.processor) @@ -110,8 +107,8 @@ def __call__(self, add_labels: bool = True) -> DatasetType: self.tokenize, batched=False, # batching is not well supported for vision processors keep_in_memory=True, # bug occurs when not batched and not in memory, - # subsequent ds.map calls are always batched, - # regardless of `batched` argument + # subsequent ds.map calls are always batched, + # regardless of `batched` argument remove_columns=dataset.column_names, num_proc=self.data_args.preprocessing_num_workers, load_from_cache_file=not self.data_args.overwrite_cache, @@ -286,6 +283,7 @@ def map( if isinstance(dataset, IterableDataset): dataset = dataset._resolve_features() + # remove columns which are present, skip removing those which are not if remove_columns is not None: if isinstance(remove_columns, str): remove_columns = [remove_columns] diff --git a/src/llmcompressor/transformers/finetune/data/data_args.py b/src/llmcompressor/transformers/finetune/data/data_args.py index c2340549f..a02975ef8 100644 --- a/src/llmcompressor/transformers/finetune/data/data_args.py +++ b/src/llmcompressor/transformers/finetune/data/data_args.py @@ -41,7 +41,7 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments): remove_columns: Union[None, str, List] = field( default=None, metadata={ - "help": "For custom datasets only. Column names to remove after " + "help": "This argument is depreciated. Column names to remove after " "preprocessing custom datasets" }, ) @@ -51,6 +51,8 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments): metadata={ "help": ( "For custom datasets only. Either a function to apply to the dataset, " + "a function name defined in " + "src/llmcompressor/transformers/utils/preprocessing_functions.py, or " "a path to a function definition of the form /path/to/file.py:func" ) }, diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index ffea2df9d..f6bf41a20 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -18,6 +18,7 @@ # neuralmagic: no copyright import os +import warnings from pathlib import PosixPath from loguru import logger @@ -118,6 +119,8 @@ def parse_args(**kwargs): * model_args in src/llmcompressor/transformers/finetune/model_args.py * data_args in src/llmcompressor/transformers/finetune/data/data_args.py * training_args in src/llmcompressor/transformers/finetune/training_args.py + + Throws depreciation warnings """ parser = HfArgumentParser( (ModelArguments, DataTrainingArguments, TrainingArguments) @@ -135,6 +138,14 @@ def parse_args(**kwargs): arg_dict[key] = value training_args.recipe_args = arg_dict + # raise depreciation warnings + if data_args.remove_columns is not None: + warnings.warn( + "`remove_columns` argument is depreciated, when processing non-tokenized " + "datasets, all columns not returned by preprocessing_fn will be removed", + DeprecationWarning + ) + return model_args, data_args, training_args diff --git a/tests/llmcompressor/transformers/oneshot/dataset_processing.py b/tests/llmcompressor/transformers/oneshot/dataset_processing.py index 4def97917..e2d16da52 100644 --- a/tests/llmcompressor/transformers/oneshot/dataset_processing.py +++ b/tests/llmcompressor/transformers/oneshot/dataset_processing.py @@ -56,7 +56,6 @@ def get_data_utils(dataset_name: str) -> Dict: Includes: 1. dataload: function to load the dataset 2. preprocess: preprocessing function to apply to the dataset - 3. remove_columns: specific columns which should be removed from the dataset :param dataset_name: the name of the dataset :returns dictionary of preprocessing functions/utils. @@ -65,12 +64,10 @@ def get_data_utils(dataset_name: str) -> Dict: "open_platypus": { "preprocess": _preprocess_alpaca, "dataload": _fetch_open_platypus_dataset, - "remove_columns": ["input", "output", "instruction", "data_source"], }, "gsm8k": { "preprocess": _preprocess_gsm, "dataload": _fetch_gsm8k_data, - "remove_columns": ["question", "answer"], }, } return data_mapping.get(dataset_name) diff --git a/tests/llmcompressor/transformers/oneshot/test_api_inputs.py b/tests/llmcompressor/transformers/oneshot/test_api_inputs.py index 3e3ee2147..216b91356 100644 --- a/tests/llmcompressor/transformers/oneshot/test_api_inputs.py +++ b/tests/llmcompressor/transformers/oneshot/test_api_inputs.py @@ -45,10 +45,7 @@ def wrapped_preprocess_func(sample): # to the loaded dataset. if self.tokenize: loaded_dataset = data_utils.get("dataload")() - self.dataset = loaded_dataset.map( - wrapped_preprocess_func, - remove_columns=data_utils.get("remove_columns"), - ) + self.dataset = loaded_dataset.map(wrapped_preprocess_func) self.tokenizer = None def test_one_shot_inputs(self): From 55a31ca32f669bcde3899ac9d090e7fdf77ad996 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Dec 2024 00:08:02 +0000 Subject: [PATCH 137/285] silently assign tokenizer to processor --- examples/multimodal_vision/qwen_vl2.py | 13 ++++++------- .../transformers/finetune/model_args.py | 3 +++ .../transformers/finetune/text_generation.py | 9 ++++++++- src/llmcompressor/utils/__init__.py | 1 - 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/examples/multimodal_vision/qwen_vl2.py b/examples/multimodal_vision/qwen_vl2.py index d6e10b10b..2de17d4ae 100644 --- a/examples/multimodal_vision/qwen_vl2.py +++ b/examples/multimodal_vision/qwen_vl2.py @@ -1,6 +1,12 @@ import os import torch +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationScheme, + QuantizationStrategy, + QuantizationType, +) from transformers import AutoProcessor, Qwen2VLForConditionalGeneration from llmcompressor.modifiers.quantization import GPTQModifier @@ -34,13 +40,6 @@ def data_collator(batch): # Recipe -from compressed_tensors.quantization import ( - QuantizationArgs, - QuantizationScheme, - QuantizationStrategy, - QuantizationType, -) - recipe = GPTQModifier( targets="Linear", config_groups={ diff --git a/src/llmcompressor/transformers/finetune/model_args.py b/src/llmcompressor/transformers/finetune/model_args.py index 8fbcad70e..2df5a7f5d 100644 --- a/src/llmcompressor/transformers/finetune/model_args.py +++ b/src/llmcompressor/transformers/finetune/model_args.py @@ -37,6 +37,9 @@ class ModelArguments: ) processor: Optional[str] = field( default=None, + metadata={ + "help": "Pretrained processor name or path if not the same as model_name" + }, ) cache_dir: Optional[str] = field( default=None, diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index f6bf41a20..cd39a2e84 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -143,9 +143,16 @@ def parse_args(**kwargs): warnings.warn( "`remove_columns` argument is depreciated, when processing non-tokenized " "datasets, all columns not returned by preprocessing_fn will be removed", - DeprecationWarning + DeprecationWarning, ) + # silently assign tokenizer to processor + if model_args.tokenizer: + if model_args.processor: + raise ValueError("Cannot use both a tokenizer and processor") + model_args.processor = model_args.tokenizer + model_args.tokenizer = None + return model_args, data_args, training_args diff --git a/src/llmcompressor/utils/__init__.py b/src/llmcompressor/utils/__init__.py index d36055acb..98d5e1c65 100644 --- a/src/llmcompressor/utils/__init__.py +++ b/src/llmcompressor/utils/__init__.py @@ -5,4 +5,3 @@ # flake8: noqa from .helpers import * -from .typing import * From 1aba16dc3ccd705b9e0986ed36fea5cd3d186e81 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Dec 2024 00:36:35 +0000 Subject: [PATCH 138/285] replace tokenizer with processor Signed-off-by: Kyle Sayers --- .../pytorch/model_load/helpers.py | 12 +++-- .../transformers/finetune/data/base.py | 44 +++++++++++----- .../transformers/finetune/data/c4.py | 6 +-- .../finetune/data/cnn_dailymail.py | 6 +-- .../transformers/finetune/data/custom.py | 6 +-- .../finetune/data/evolcodealpaca.py | 6 +-- .../transformers/finetune/data/gsm8k.py | 6 +-- .../finetune/data/open_platypus.py | 6 +-- .../transformers/finetune/data/ptb.py | 6 +-- .../finetune/data/ultrachat_200k.py | 10 ++-- .../transformers/finetune/data/wikitext.py | 6 +-- .../transformers/finetune/model_args.py | 6 +++ .../transformers/finetune/runner.py | 16 +++--- .../transformers/finetune/session_mixin.py | 5 +- .../transformers/finetune/text_generation.py | 52 +++++++++++-------- .../compressed_tensors_utils.py | 5 +- .../sparsification/sparse_model.py | 8 +-- .../utils/preprocessing_functions.py | 7 ++- src/llmcompressor/typing.py | 17 ++++++ src/llmcompressor/utils/fsdp/helpers.py | 7 +-- .../compression/test_quantization.py | 2 +- .../finetune/data/test_dataset_loading.py | 20 +++---- .../finetune/data/test_registry.py | 6 +-- .../transformers/obcq/test_obcq_completion.py | 2 +- tests/testing_utils.py | 4 +- 25 files changed, 165 insertions(+), 106 deletions(-) create mode 100644 src/llmcompressor/typing.py diff --git a/src/llmcompressor/pytorch/model_load/helpers.py b/src/llmcompressor/pytorch/model_load/helpers.py index 3db9be173..a9ecb67a7 100644 --- a/src/llmcompressor/pytorch/model_load/helpers.py +++ b/src/llmcompressor/pytorch/model_load/helpers.py @@ -9,6 +9,7 @@ from llmcompressor.core import active_session, create_session, pre_initialize_structure from llmcompressor.pytorch.utils import ModuleSparsificationInfo +from llmcompressor.typing import Processor COMPLETED_STAGES_FILENAME = "completed_stages.json" @@ -92,15 +93,16 @@ def initialize_recipe(model: Module, recipe_path: str): def save_model_and_recipe( model: Module, save_path: str, - tokenizer: Optional[Any] = None, + processor: Optional[Processor] = None, save_safetensors: bool = False, save_compressed: bool = False, ): """ - Save a model, tokenizer and the currently loaded recipe to file + Save a model, processor and the currently loaded recipe to file + :param model: pytorch model to save :param save_path: path to save output to - :param tokenizer: model tokenizer to save + :param processor: model processor or tokenizer to save :param save_safetensors: whether to save as safetensors or pickle (bin) :param save_compressed: whether to compress sparse weights on disk """ @@ -111,8 +113,8 @@ def save_model_and_recipe( save_path, save_compressed=save_compressed, safe_serialization=save_safetensors ) - if tokenizer is not None: - tokenizer.save_pretrained(save_path) + if processor is not None: + processor.save_pretrained(save_path) logger.info("Saving output to {}".format(os.path.abspath(save_path))) diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index d4c3a6222..3b68e0fc1 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -3,7 +3,6 @@ from compressed_tensors.registry import RegistryMixin from datasets import Dataset, IterableDataset from loguru import logger -from transformers import AutoTokenizer from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments from llmcompressor.transformers.finetune.data.data_helpers import ( @@ -11,6 +10,7 @@ get_custom_datasets_from_path, get_raw_dataset, ) +from llmcompressor.typing import Processor class TextGenerationDataset(RegistryMixin): @@ -30,10 +30,10 @@ def __init__( text_column: str, data_args: DataTrainingArguments, split: str, - tokenizer: AutoTokenizer, + processor: Processor, ): self.text_column = text_column - self.tokenizer = tokenizer + self.processor = processor self.data_args = data_args self.raw_kwargs = data_args.raw_kwargs or {} self.split = split @@ -50,20 +50,38 @@ def __init__( else: self.padding = False - if self.tokenizer: + # get tokenizer + self.tokenizer = getattr(self.processor, "tokenizer", self.processor) + + if self.tokenizer is not None: + # fill in pad token if not self.tokenizer.pad_token: self.tokenizer.pad_token = self.tokenizer.eos_token - # configure sequence length - max_seq_length = data_args.max_seq_length - model_max_length = tokenizer.model_max_length if tokenizer else max_seq_length - if self.tokenizer and max_seq_length > model_max_length: - logger.warning( - f"The max_seq_length passed ({max_seq_length}) is larger than " - f"the maximum length for the model ({tokenizer.model_max_length}). " - f"Using max_seq_length={tokenizer.model_max_length}." + # configure sequence length + max_seq_length = data_args.max_seq_length + if data_args.max_seq_length > self.tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({max_seq_length}) is larger than " + f"maximum length for model ({self.tokenizer.model_max_length}). " + f"Using max_seq_length={self.tokenizer.model_max_length}." + ) + self.max_seq_length = min( + data_args.max_seq_length, self.tokenizer.model_max_length + ) + + # configure padding + self.padding = ( + False + if self.data_args.concatenate_data + else "max_length" + if self.data_args.pad_to_max_length + else False ) - self.max_seq_length = min(data_args.max_seq_length, model_max_length) + + else: + self.max_seq_length = None + self.padding = False def get_raw_dataset(self, cache_dir: Optional[str] = None) -> Dataset: """ diff --git a/src/llmcompressor/transformers/finetune/data/c4.py b/src/llmcompressor/transformers/finetune/data/c4.py index 37eeceae6..91cbc58e8 100644 --- a/src/llmcompressor/transformers/finetune/data/c4.py +++ b/src/llmcompressor/transformers/finetune/data/c4.py @@ -10,12 +10,12 @@ class C4Dataset(TextGenerationDataset): :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` - :param tokenizer: tokenizer to use on dataset + :param processor: processor or tokenizer to use on dataset """ - def __init__(self, data_args, split, tokenizer): + def __init__(self, data_args, split, processor): data_args = deepcopy(data_args) data_args.dataset = "allenai/c4" super().__init__( - text_column="text", data_args=data_args, split=split, tokenizer=tokenizer + text_column="text", data_args=data_args, split=split, processor=processor ) diff --git a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py index 64755de4a..dcebe7573 100644 --- a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py +++ b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py @@ -24,18 +24,18 @@ class CNNDailyMailDataset(TextGenerationDataset): :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` - :param tokenizer: tokenizer to use on dataset + :param processor: processor or tokenizer to use on dataset """ SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n" - def __init__(self, data_args, split, tokenizer): + def __init__(self, data_args, split, processor): data_args = deepcopy(data_args) data_args.dataset = "cnn_dailymail" data_args.dataset_config_name = "3.0.0" super().__init__( - text_column="text", data_args=data_args, split=split, tokenizer=tokenizer + text_column="text", data_args=data_args, split=split, processor=processor ) def get_raw_dataset(self, cache_dir: Optional[str] = None): diff --git a/src/llmcompressor/transformers/finetune/data/custom.py b/src/llmcompressor/transformers/finetune/data/custom.py index e849594e7..817cb34de 100644 --- a/src/llmcompressor/transformers/finetune/data/custom.py +++ b/src/llmcompressor/transformers/finetune/data/custom.py @@ -32,17 +32,17 @@ class CustomDataset(TextGenerationDataset): :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` Can also be set to None to load all the splits - :param tokenizer: tokenizer to use on dataset + :param processor: processor or tokenizer to use on dataset """ - def __init__(self, data_args, split, tokenizer): + def __init__(self, data_args, split, processor): data_args = deepcopy(data_args) super().__init__( text_column=data_args.text_column, data_args=data_args, split=split, - tokenizer=tokenizer, + processor=processor, ) self.preprocessing_func = data_args.preprocessing_func self.remove_columns = data_args.remove_columns diff --git a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py index 9529d3115..66505f117 100644 --- a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py +++ b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py @@ -24,7 +24,7 @@ class EvolCodeAlpacaDataset(TextGenerationDataset): :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` - :param tokenizer: tokenizer to use on dataset + :param processor: processor or tokenizer to use on dataset """ EVOL_ALPACA_TEMPLATE = ( @@ -34,11 +34,11 @@ class EvolCodeAlpacaDataset(TextGenerationDataset): "\n\n### Response:\n" ) - def __init__(self, data_args, split, tokenizer): + def __init__(self, data_args, split, processor): data_args = deepcopy(data_args) data_args.dataset = "theblackcat102/evol-codealpaca-v1" super().__init__( - text_column="text", data_args=data_args, split=split, tokenizer=tokenizer + text_column="text", data_args=data_args, split=split, processor=processor ) def get_raw_dataset(self, cache_dir: Optional[str] = None): diff --git a/src/llmcompressor/transformers/finetune/data/gsm8k.py b/src/llmcompressor/transformers/finetune/data/gsm8k.py index f9a94bcf4..299ae1bb2 100644 --- a/src/llmcompressor/transformers/finetune/data/gsm8k.py +++ b/src/llmcompressor/transformers/finetune/data/gsm8k.py @@ -11,16 +11,16 @@ class GSM8KDataset(TextGenerationDataset): :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` - :param tokenizer: tokenizer to use on dataset + :param processor: processor or tokenizer to use on dataset """ GSM_TEMPLATE = "Question: {question}\nAnswer:" - def __init__(self, data_args, split, tokenizer): + def __init__(self, data_args, split, processor): data_args = deepcopy(data_args) data_args.dataset = "gsm8k" super().__init__( - text_column="text", data_args=data_args, split=split, tokenizer=tokenizer + text_column="text", data_args=data_args, split=split, processor=processor ) def get_raw_dataset(self, cache_dir: Optional[str] = None): diff --git a/src/llmcompressor/transformers/finetune/data/open_platypus.py b/src/llmcompressor/transformers/finetune/data/open_platypus.py index 55e54cbce..7a17c6fde 100644 --- a/src/llmcompressor/transformers/finetune/data/open_platypus.py +++ b/src/llmcompressor/transformers/finetune/data/open_platypus.py @@ -24,7 +24,7 @@ class OpenPlatypusDataset(TextGenerationDataset): :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` - :param tokenizer: tokenizer to use on dataset + :param processor: processor or tokenizer to use on dataset """ ALPACA_TEMPLATE = { @@ -37,11 +37,11 @@ class OpenPlatypusDataset(TextGenerationDataset): "instruction}\n\n### Response:\n", } - def __init__(self, data_args, split, tokenizer): + def __init__(self, data_args, split, processor): data_args = deepcopy(data_args) data_args.dataset = "garage-bAInd/Open-Platypus" super().__init__( - text_column="text", data_args=data_args, split=split, tokenizer=tokenizer + text_column="text", data_args=data_args, split=split, processor=processor ) def get_raw_dataset(self, cache_dir: Optional[str] = None): diff --git a/src/llmcompressor/transformers/finetune/data/ptb.py b/src/llmcompressor/transformers/finetune/data/ptb.py index 6f502edaf..8519f023c 100644 --- a/src/llmcompressor/transformers/finetune/data/ptb.py +++ b/src/llmcompressor/transformers/finetune/data/ptb.py @@ -10,15 +10,15 @@ class PtbDataset(TextGenerationDataset): :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` - :param tokenizer: tokenizer to use on dataset + :param processor: processor or tokenizer to use on dataset """ - def __init__(self, data_args, split, tokenizer): + def __init__(self, data_args, split, processor): data_args = deepcopy(data_args) data_args.dataset = "ptb_text_only" super().__init__( text_column="sentence", data_args=data_args, split=split, - tokenizer=tokenizer, + processor=processor, ) diff --git a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py index 5b2e66ab5..30607847d 100644 --- a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py +++ b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py @@ -24,7 +24,7 @@ class UltraChatDataset(TextGenerationDataset): :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` - :param tokenizer: tokenizer to use on dataset + :param processor: processor or tokenizer to use on dataset """ DEFAULT_CHAT_TEMPLATE = ( @@ -40,7 +40,7 @@ class UltraChatDataset(TextGenerationDataset): "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" ) - def __init__(self, data_args, split, tokenizer): + def __init__(self, data_args, split, processor): data_args = deepcopy(data_args) data_args.dataset = "HuggingFaceH4/ultrachat_200k" @@ -51,13 +51,15 @@ def __init__(self, data_args, split, tokenizer): text_column="messages", data_args=data_args, split=split, - tokenizer=tokenizer, + processor=processor, ) if ( not hasattr(self.tokenizer, "chat_template") or self.tokenizer.chat_template is None ): + # note that since tokenizer is a member of processor, + # this change affects processor.apply_chat_template self.tokenizer.chat_template = self.DEFAULT_CHAT_TEMPLATE def get_raw_dataset(self, cache_dir: Optional[str] = None): @@ -75,7 +77,7 @@ def restructure_fn(sample): if sample["messages"][0]["role"] != "system": sample["messages"].insert(0, {"role": "system", "content": ""}) - sample["messages"] = self.tokenizer.apply_chat_template( + sample["messages"] = self.processor.apply_chat_template( sample["messages"], tokenize=False, add_generation_prompt=False ) return sample diff --git a/src/llmcompressor/transformers/finetune/data/wikitext.py b/src/llmcompressor/transformers/finetune/data/wikitext.py index 034d58ba2..25280589c 100644 --- a/src/llmcompressor/transformers/finetune/data/wikitext.py +++ b/src/llmcompressor/transformers/finetune/data/wikitext.py @@ -8,10 +8,10 @@ class WikiTextDataset(TextGenerationDataset): :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` - :param tokenizer: tokenizer to use on dataset + :param processor: processor or tokenizer to use on dataset """ - def __init__(self, data_args, split, tokenizer): + def __init__(self, data_args, split, processor): super().__init__( - text_column="text", data_args=data_args, split=split, tokenizer=tokenizer + text_column="text", data_args=data_args, split=split, processor=processor ) diff --git a/src/llmcompressor/transformers/finetune/model_args.py b/src/llmcompressor/transformers/finetune/model_args.py index d3d8e974f..c81900ee2 100644 --- a/src/llmcompressor/transformers/finetune/model_args.py +++ b/src/llmcompressor/transformers/finetune/model_args.py @@ -34,6 +34,12 @@ class ModelArguments: "help": "Pretrained tokenizer name or path if not the same as model_name" }, ) + processor: Optional[str] = field( + default=None, + metadata={ + "help": "Pretrained processor name or path if not the same as model_name" + }, + ) cache_dir: Optional[str] = field( default=None, metadata={"help": "Where to store the pretrained data from huggingface.co"}, diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py index 6344b1a2b..131180199 100644 --- a/src/llmcompressor/transformers/finetune/runner.py +++ b/src/llmcompressor/transformers/finetune/runner.py @@ -6,7 +6,6 @@ import torch from loguru import logger from torch.utils.data import Dataset -from transformers import AutoTokenizer from llmcompressor.core import active_session from llmcompressor.pytorch.model_load.helpers import ( @@ -24,6 +23,7 @@ ) from llmcompressor.transformers.finetune.model_args import ModelArguments from llmcompressor.transformers.finetune.training_args import TrainingArguments +from llmcompressor.typing import Processor from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_model_and_recipe @@ -38,7 +38,7 @@ class StageRunner: - set_trainer() - train() / evaluate() / predict() - :param model_args: Arguments pertaining to model/config/tokenizer + :param model_args: Arguments pertaining to model/config/processor :param data_args: Arguments pertaining to what data to use for different flows :param training_args: Arguments pertaining to training loop configuration :model: unwrapped model to run flows on @@ -56,11 +56,11 @@ def __init__( self.datasets = {} self.trainer = None - self.tokenizer = None + self.processor = None self.parent_output_dir = self._training_args.output_dir self._output_dir = self._training_args.output_dir - def populate_datasets(self, tokenizer: "AutoTokenizer", add_labels: bool = True): + def populate_datasets(self, processor: Processor, add_labels: bool = True): """ Loads datasets for each flow based on data_args, stores a Dataset for each enabled flow in self.datasets @@ -68,7 +68,7 @@ def populate_datasets(self, tokenizer: "AutoTokenizer", add_labels: bool = True) :param tokenizer: tokenizer to use for dataset tokenization """ if self._data_args.dataset is None: - self.tokenizer = self._model_args.tokenizer + self.processor = self._model_args.processor logger.info( "Running oneshot without calibration data. This is expected for " "weight-only and dynamic quantization" @@ -102,7 +102,7 @@ def _get_split_name(inp_str): registry_id, data_args=self._data_args, split=split_str, - tokenizer=tokenizer, + processor=processor, ) dataset = self._data_args.dataset @@ -124,7 +124,7 @@ def _get_split_name(inp_str): do_predict=self._training_args.do_predict, do_oneshot=self._training_args.do_oneshot, ) - self.tokenizer = tokenizer + self.processor = processor def get_dataset_split(self, split_name: str) -> Dataset: """ @@ -266,7 +266,7 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None): save_model_and_recipe( model=self.trainer.model, save_path=self._output_dir, - tokenizer=self.tokenizer, + processor=self.processor, save_safetensors=self._training_args.save_safetensors, save_compressed=self._training_args.save_compressed, ) diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py index b1ac57b95..27860aeb4 100644 --- a/src/llmcompressor/transformers/finetune/session_mixin.py +++ b/src/llmcompressor/transformers/finetune/session_mixin.py @@ -487,8 +487,9 @@ def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False): ) self.save_state() - if self.tokenizer is not None: - self.tokenizer.save_pretrained(output_dir) + processor = getattr(self, "processing_class", self.tokenizer) + if processor is not None: + processor.save_pretrained(output_dir) if not self.recipe: return diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index 85aa6d82c..a6c21fc39 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -24,9 +24,9 @@ from transformers import ( AutoConfig, AutoModelForCausalLM, - AutoTokenizer, - DefaultDataCollator, + AutoProcessor, HfArgumentParser, + PreTrainedModel, set_seed, ) @@ -49,9 +49,10 @@ patch_tied_tensors_bug, ) from llmcompressor.transformers.sparsification.sparse_model import ( - get_shared_tokenizer_src, + get_shared_processor_src, ) from llmcompressor.transformers.utils.helpers import detect_last_checkpoint +from llmcompressor.typing import Processor from llmcompressor.utils.fsdp.helpers import is_fsdp_model @@ -134,6 +135,13 @@ def parse_args(**kwargs): arg_dict[key] = value training_args.recipe_args = arg_dict + # silently assign tokenizer to processor + if model_args.tokenizer: + if model_args.processor: + raise ValueError("Cannot use both a tokenizer and processor") + model_args.processor = model_args.tokenizer + model_args.tokenizer = None + return model_args, data_args, training_args @@ -226,11 +234,13 @@ def initialize_model_from_path( return teacher, model_path, model -def initialize_tokenizer_from_path(model_args, model, teacher): - tokenizer_src = model_args.tokenizer - tokenizer_src = tokenizer_src or get_shared_tokenizer_src(model, teacher) - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_src, +def initialize_processor_from_path( + model_args: ModelArguments, model: PreTrainedModel, teacher: PreTrainedModel +) -> Processor: + processor_src = model_args.processor + processor_src = processor_src or get_shared_processor_src(model, teacher) + processor = AutoProcessor.from_pretrained( + processor_src, cache_dir=model_args.cache_dir, use_fast=True, revision=model_args.model_revision, @@ -238,7 +248,7 @@ def initialize_tokenizer_from_path(model_args, model, teacher): trust_remote_code=model_args.trust_remote_code_model, ) - return tokenizer + return processor def main( @@ -299,11 +309,9 @@ def main( # Detecting last checkpoint. last_checkpoint = None teacher = model_args.distill_teacher - model = model_args.model - # Load tokenizer - # distill TODO: support for different tokenizer for teacher? - tokenizer = model_args.tokenizer + # distill TODO: support for different processor for teacher? + model = model_args.model if isinstance(model, str) or isinstance(model, PosixPath): (teacher, _model_path, model) = initialize_model_from_path( model_args, @@ -317,8 +325,9 @@ def main( if teacher is not None: teacher.eval() - if isinstance(tokenizer, str) or tokenizer is None: - tokenizer = initialize_tokenizer_from_path(model_args, model, teacher) + processor = model_args.processor + if isinstance(processor, str) or processor is None: + processor = initialize_processor_from_path(model_args, model, teacher) pre_initialize_structure(model=model) @@ -330,13 +339,12 @@ def main( model_args=model_args, data_args=data_args, training_args=training_args ) add_labels = training_args.do_train or training_args.run_stages - stage_runner.populate_datasets(tokenizer=tokenizer, add_labels=add_labels) + stage_runner.populate_datasets(processor=processor, add_labels=add_labels) train_dataset = stage_runner.get_dataset_split("train") eval_dataset = stage_runner.get_dataset_split("validation") calib_dataset = stage_runner.get_dataset_split("calibration") # Initialize our Trainer - data_collator = DefaultDataCollator() trainer = Trainer( model_init=get_session_model, teacher=teacher, @@ -346,13 +354,13 @@ def main( data_args=data_args, train_dataset=train_dataset or calib_dataset, eval_dataset=eval_dataset, - tokenizer=tokenizer, - data_collator=data_collator, + processing_class=processor, + data_collator=data_args.data_collator, ) # wrap model.save_pretrained if is_fsdp_model(model): - modify_fsdp_model_save_pretrained(trainer, tokenizer) + modify_fsdp_model_save_pretrained(trainer, processor) else: modify_save_pretrained(model) @@ -396,8 +404,8 @@ def main( model.save_pretrained( training_args.output_dir, save_compressed=training_args.save_compressed ) - if tokenizer is not None: - tokenizer.save_pretrained(training_args.output_dir) + if processor is not None: + processor.save_pretrained(training_args.output_dir) # Clean up the CompressionSession before exit if requested if training_args.clear_sparse_session: diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index 759098894..ce4ae7fb2 100644 --- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -25,6 +25,7 @@ SparsityConfigMetadata, ) from llmcompressor.transformers.utils import RECIPE_FILE_NAME +from llmcompressor.typing import Processor from llmcompressor.utils.fsdp.helpers import ( find_and_move_state_dicts_to_cpu, unwrap_and_export_model, @@ -33,7 +34,7 @@ __all__ = ["modify_save_pretrained", "modify_fsdp_model_save_pretrained"] -def modify_fsdp_model_save_pretrained(trainer, tokenizer): +def modify_fsdp_model_save_pretrained(trainer, processor: Processor): """ Overrides a PreTrainedModel's save_pretrained() method with a wrapped version that supports compression for fsdp model @@ -78,7 +79,7 @@ def save_pretrained_wrapper( model=trainer.model, accelerator=trainer.accelerator, output_dir=save_directory, - tokenizer=tokenizer, + processor=processor, ) # only allow the main process move the state # dicts to cpu diff --git a/src/llmcompressor/transformers/sparsification/sparse_model.py b/src/llmcompressor/transformers/sparsification/sparse_model.py index bf09396d7..d7abc323a 100644 --- a/src/llmcompressor/transformers/sparsification/sparse_model.py +++ b/src/llmcompressor/transformers/sparsification/sparse_model.py @@ -7,7 +7,7 @@ __all__ = [ "SparseAutoModelForCausalLM", - "get_shared_tokenizer_src", + "get_shared_processor_src", ] @@ -20,14 +20,14 @@ def from_pretrained(*args, **kwargs): return AutoModelForCausalLM.from_pretrained(*args, **kwargs) -def get_shared_tokenizer_src(student: Module, teacher: Optional[Module]) -> str: +def get_shared_processor_src(student: Module, teacher: Optional[Module]) -> str: """ - Get a tokenizer source used for both student and teacher, assuming + Get a processor/tokenizer source used for both student and teacher, assuming that they could be shared :param student: the student model :param teacher: the teacher model - :return: the source for the tokenizer shared between teacher and model + :return: the source for the processor/tokenizer shared between teacher and model """ if teacher is not None and teacher not in ("disable", "self"): diff --git a/src/llmcompressor/transformers/utils/preprocessing_functions.py b/src/llmcompressor/transformers/utils/preprocessing_functions.py index cadec88f0..6bf6ade42 100644 --- a/src/llmcompressor/transformers/utils/preprocessing_functions.py +++ b/src/llmcompressor/transformers/utils/preprocessing_functions.py @@ -1,14 +1,17 @@ -from typing import Dict +from typing import TYPE_CHECKING, Dict from compressed_tensors.registry import RegistryMixin +if TYPE_CHECKING: + from llmcompressor.transformers.finetune.data.base import TextGenerationDataset + class PreprocessingFunctionRegistry(RegistryMixin): pass @PreprocessingFunctionRegistry.register() -def custom_evolved_codealpaca_dataset(data: Dict): +def custom_evolved_codealpaca_dataset(self: "TextGenerationDataset", data: Dict): PROMPT_DICT = """[Instruction]:\n{instruction}\n\n[Response]:""" data["prompt"] = PROMPT_DICT.format_map(data) data["text"] = data["prompt"] + data["output"] diff --git a/src/llmcompressor/typing.py b/src/llmcompressor/typing.py new file mode 100644 index 000000000..1050f7138 --- /dev/null +++ b/src/llmcompressor/typing.py @@ -0,0 +1,17 @@ +from typing import Union + +from datasets import Dataset, DatasetDict, IterableDataset +from transformers import ( + BaseImageProcessor, + FeatureExtractionMixin, + PreTrainedTokenizer, + ProcessorMixin, +) + +# Tokenizer or Processor. Processors do not inherit from a unified base class +Processor = Union[ + PreTrainedTokenizer, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin +] + +# Supported dataset types, IterableDataset is a streamed dataset +DatasetType = Union[Dataset, DatasetDict, IterableDataset] diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py index 8cc0f5405..3a3248fa5 100644 --- a/src/llmcompressor/utils/fsdp/helpers.py +++ b/src/llmcompressor/utils/fsdp/helpers.py @@ -18,6 +18,7 @@ from llmcompressor.core.state import State from llmcompressor.pytorch.model_load.helpers import save_model_and_recipe +from llmcompressor.typing import Processor from llmcompressor.utils.pytorch import set_layer __all__ = [ @@ -71,7 +72,7 @@ def set_wrapped_model(state: State, wrapped_model: Module): state.model = wrapped_model -def unwrap_and_export_model(model, accelerator, output_dir, tokenizer): +def unwrap_and_export_model(model, accelerator, output_dir: str, processor: Processor): """ Recursively unwraps an FSDP model, then saves the unwrapped model and the currently active recipe to disk @@ -79,7 +80,7 @@ def unwrap_and_export_model(model, accelerator, output_dir, tokenizer): :param model: model to unwrap :param accelerator: Accelerator instance used to perform unwrapping :param output_dir: where to save output model - :param tokenizer: tokenizer used by the model + :param processor: processor used by the model """ full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) with FullyShardedDataParallel.state_dict_type( @@ -95,7 +96,7 @@ def unwrap_and_export_model(model, accelerator, output_dir, tokenizer): save_model_and_recipe( model=unwrapped_model, save_path=output_dir, - tokenizer=tokenizer, + processor=processor, ) diff --git a/tests/llmcompressor/transformers/compression/test_quantization.py b/tests/llmcompressor/transformers/compression/test_quantization.py index 2dd1249d6..4a37f138d 100644 --- a/tests/llmcompressor/transformers/compression/test_quantization.py +++ b/tests/llmcompressor/transformers/compression/test_quantization.py @@ -133,7 +133,7 @@ def _get_dataloader(self, data_args, tokenizer): data_args.dataset, data_args=data_args, split="train_gen[:5%]", - tokenizer=tokenizer, + processor=tokenizer, ) calib_dataset = dataset_manager.tokenize_and_process( dataset_manager.get_raw_dataset() diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py index 3415858af..cbedd5b9d 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py @@ -29,7 +29,7 @@ def test_concatenation_tokenization(self): self.data_args.dataset, data_args=self.data_args, split="train[:5%]", - tokenizer=self.tiny_llama_tokenizer, + processor=self.tiny_llama_tokenizer, ) raw_dataset = wiki_manager.get_raw_dataset() self.assertGreater(len(raw_dataset), 0) @@ -61,7 +61,7 @@ def test_no_padding_tokenization(self): self.data_args.dataset, data_args=self.data_args, split="train[5%:10%]", - tokenizer=self.tiny_llama_tokenizer, + processor=self.tiny_llama_tokenizer, ) raw_dataset = op_manager.get_raw_dataset() self.assertGreater(len(raw_dataset), 0) @@ -96,7 +96,7 @@ def test_max_seq_len_clipped(self): self.data_args.dataset, data_args=self.data_args, split="train[80%:]", - tokenizer=self.tiny_llama_tokenizer, + processor=self.tiny_llama_tokenizer, ) self.assertEqual( @@ -125,7 +125,7 @@ def test_dataset_kwargs_and_percentages(self): self.data_args.dataset, data_args=self.data_args, split="train[5%:10%]", - tokenizer=self.tiny_llama_tokenizer, + processor=self.tiny_llama_tokenizer, ) raw_dataset_a = c4_manager_a.get_raw_dataset() @@ -133,7 +133,7 @@ def test_dataset_kwargs_and_percentages(self): self.data_args.dataset, data_args=self.data_args, split="train[5%:15%]", - tokenizer=self.tiny_llama_tokenizer, + processor=self.tiny_llama_tokenizer, ) raw_dataset_b = c4_manager_b.get_raw_dataset() @@ -164,7 +164,7 @@ def test_datasets(self, dataset_key, dataset_config, split, do_concat): data_args.dataset, data_args=data_args, split=split, - tokenizer=self.tiny_llama_tokenizer, + processor=self.tiny_llama_tokenizer, ) raw_dataset = manager.get_raw_dataset() self.assertGreater(len(raw_dataset), 0) @@ -204,7 +204,7 @@ def test_evol(self): self.data_args.dataset, data_args=self.data_args, split="train[:2%]", - tokenizer=self.tiny_llama_tokenizer, + processor=self.tiny_llama_tokenizer, ) raw_dataset = evol_manager.get_raw_dataset() self.assertGreater(len(raw_dataset), 0) @@ -238,7 +238,7 @@ def test_stream_loading(self): self.data_args.dataset, data_args=self.data_args, split="train", - tokenizer=self.tiny_llama_tokenizer, + processor=self.tiny_llama_tokenizer, ) raw_dataset = manager.get_raw_dataset() @@ -276,7 +276,7 @@ def test_split_loading(self, split_def): stage_runner = StageRunner( model_args=model_args, data_args=data_args, training_args=training_args ) - stage_runner.populate_datasets(tokenizer=self.tiny_llama_tokenizer) + stage_runner.populate_datasets(processor=self.tiny_llama_tokenizer) train_dataset = stage_runner.get_dataset_split("train") assert train_dataset is not None @@ -320,7 +320,7 @@ def preprocess(sample): ), training_args=TrainingArguments(do_oneshot=True), ) - stage_runner.populate_datasets(tokenizer=None) + stage_runner.populate_datasets(processor=None) calib_dataset = stage_runner.get_dataset_split("calibration") self.assertEqual(len(calib_dataset), self.num_calib_samples) data_cols = calib_dataset.column_names diff --git a/tests/llmcompressor/transformers/finetune/data/test_registry.py b/tests/llmcompressor/transformers/finetune/data/test_registry.py index e4c804c07..3350d0a79 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_registry.py +++ b/tests/llmcompressor/transformers/finetune/data/test_registry.py @@ -16,7 +16,7 @@ def test_c4_initializes(tiny_llama_tokenizer): data_args.dataset, data_args=data_args, split=None, - tokenizer=tiny_llama_tokenizer, + processor=tiny_llama_tokenizer, ) assert isinstance(c4_manager, TextGenerationDataset) assert isinstance(c4_manager, C4Dataset) @@ -34,7 +34,7 @@ def test_wikitext_initializes(tiny_llama_tokenizer): data_args.dataset, data_args=data_args, split=None, - tokenizer=tiny_llama_tokenizer, + processor=tiny_llama_tokenizer, ) assert isinstance(wiki_manager, TextGenerationDataset) assert isinstance(wiki_manager, WikiTextDataset) @@ -50,7 +50,7 @@ def test_open_platypus_initializes(tiny_llama_tokenizer): data_args.dataset, data_args=data_args, split=None, - tokenizer=tiny_llama_tokenizer, + processor=tiny_llama_tokenizer, ) assert isinstance(op_manager, TextGenerationDataset) assert isinstance(op_manager, OpenPlatypusDataset) diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py index 03517de07..096c8df94 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py @@ -37,7 +37,7 @@ def labeled_dataloader(self, dataset_name, model_name): data_args.dataset, data_args=data_args, split="train", - tokenizer=tokenizer, + processor=tokenizer, ) calib_dataset = dataset_manager.tokenize_and_process( dataset_manager.get_raw_dataset() diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 7a5dab66f..c28a25545 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -9,7 +9,7 @@ import yaml from datasets import Dataset -from transformers import AutoTokenizer +from transformers import PreTrainedTokenizer from tests.data import CustomTestConfig, TestConfig @@ -130,7 +130,7 @@ def run_cli_command(cmd: List[str], cwd: Optional[Union[str, Path]] = None): def preprocess_tokenize_dataset( - ds: Dataset, tokenizer: AutoTokenizer, max_seq_length: int + ds: Dataset, tokenizer: PreTrainedTokenizer, max_seq_length: int ) -> Dataset: """ Helper function to preprocess and tokenize a dataset according to presets From 89bda306ff43a96c7577fafc313e17d417487af7 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Dec 2024 04:22:28 +0000 Subject: [PATCH 139/285] defer data collator changes Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/finetune/text_generation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index a6c21fc39..f0e3a6b16 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -25,6 +25,7 @@ AutoConfig, AutoModelForCausalLM, AutoProcessor, + DefaultDataCollator, HfArgumentParser, PreTrainedModel, set_seed, @@ -345,6 +346,7 @@ def main( calib_dataset = stage_runner.get_dataset_split("calibration") # Initialize our Trainer + data_collator = DefaultDataCollator() trainer = Trainer( model_init=get_session_model, teacher=teacher, @@ -355,7 +357,7 @@ def main( train_dataset=train_dataset or calib_dataset, eval_dataset=eval_dataset, processing_class=processor, - data_collator=data_args.data_collator, + data_collator=data_collator, ) # wrap model.save_pretrained From 0fa410214ddbefd0cd99789885316d784b4a74bf Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Dec 2024 04:41:20 +0000 Subject: [PATCH 140/285] reduce warnings Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/finetune/data/data_helpers.py | 2 +- tests/llmcompressor/transformers/obcq/test_obcq_sparsity.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index 148cf85af..9f14518f5 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -117,7 +117,7 @@ def make_dataset_splits( if "all" in datasets and len(datasets) == 1: datasets = datasets.get("all") if isinstance(datasets, Dataset): - datasets = {"train": datasets} + datasets = {"train": datasets, "calibration": datasets} # shallow copy train_split = eval_split = predict_split = calib_split = None diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_sparsity.py b/tests/llmcompressor/transformers/obcq/test_obcq_sparsity.py index 0e80b6d0c..926ce5fbe 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_sparsity.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_sparsity.py @@ -13,6 +13,7 @@ @requires_torch @pytest.mark.integration +@pytest.mark.filterwarnings("ignore::UserWarning") @parameterized_class(parse_params(CONFIGS_DIRECTORY)) class TestSparsities(unittest.TestCase): model = None @@ -61,6 +62,7 @@ def tearDown(self): @requires_gpu @requires_torch @pytest.mark.integration +@pytest.mark.filterwarnings("ignore::UserWarning") @parameterized_class(parse_params(GPU_CONFIGS_DIRECTORY)) class TestSparsitiesGPU(unittest.TestCase): model = None From bc505bf31815f037eccf95d9ff797cb9b166028b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Dec 2024 04:53:42 +0000 Subject: [PATCH 141/285] typehinting, add not-implemented error Signed-off-by: Kyle Sayers --- .../transformers/finetune/data/base.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index 425cc388d..ea3c71e7b 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -3,6 +3,7 @@ from compressed_tensors.registry import RegistryMixin from datasets import Dataset, IterableDataset +from datasets.formatting.formatting import LazyRow from loguru import logger from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments @@ -14,8 +15,8 @@ from llmcompressor.transformers.utils.preprocessing_functions import ( PreprocessingFunctionRegistry, ) +from llmcompressor.typing import DatasetType, Processor from llmcompressor.utils import import_from_path -from llmcompressor.typing import Processor, DatasetType class TextGenerationDataset(RegistryMixin): @@ -172,7 +173,7 @@ def load_dataset(self): ) @cached_property - def preprocess(self) -> Union[Callable[[Any], Any], None]: + def preprocess(self) -> Union[Callable[[LazyRow], Any], None]: """ The function must return keys which correspond to tokenizer kwargs, optionally including PROMPT_KEY @@ -208,7 +209,7 @@ def rename_columns(self, dataset: DatasetType) -> DatasetType: return dataset - def tokenize(self, data: Dict[str, Any]) -> Dict[str, Any]: + def tokenize(self, data: LazyRow) -> Dict[str, Any]: # separate prompt prompt = data.pop(self.PROMPT_KEY, None) @@ -230,7 +231,7 @@ def tokenize(self, data: Dict[str, Any]) -> Dict[str, Any]: return data - def group_text(self, data: Dict[str, Any]) -> Dict[str, Any]: + def group_text(self, data: LazyRow) -> Dict[str, Any]: concatenated_data = {k: sum(data[k], []) for k in data.keys()} total_length = len(concatenated_data[list(data.keys())[0]]) total_length = (total_length // self.max_seq_length) * self.max_seq_length @@ -243,7 +244,12 @@ def group_text(self, data: Dict[str, Any]) -> Dict[str, Any]: } return result - def add_labels(self, data): + def add_labels(self, data: LazyRow) -> LazyRow: + if "pixel_values" in data: + raise NotImplementedError( + "Label masking for vision datasets has not been implemented yet" + ) + # if the dataset uses prompts, mask them out so they don't contribute # to the loss calculation prompt_len = 0 From c91ba773666cd7ba8136344932cdc9183d22e7db Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Dec 2024 04:55:27 +0000 Subject: [PATCH 142/285] remove todos Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/finetune/data/flickr_30k.py | 3 --- src/llmcompressor/transformers/finetune/data/ultrachat_200k.py | 3 --- src/llmcompressor/transformers/finetune/model_args.py | 1 - 3 files changed, 7 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/flickr_30k.py b/src/llmcompressor/transformers/finetune/data/flickr_30k.py index 03fd0fa7b..2c55bf42d 100644 --- a/src/llmcompressor/transformers/finetune/data/flickr_30k.py +++ b/src/llmcompressor/transformers/finetune/data/flickr_30k.py @@ -44,9 +44,6 @@ def __init__(self, data_args: "DataArgs", split: str, processor: Processor): self.tokenizer.chat_template = self.DEFAULT_CHAT_TEMPLATE def dataset_template(self, sample): - if self.processor is None: - raise ValueError("TODO") - messages = [ { "role": "user", diff --git a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py index b11272de2..47af48cdd 100644 --- a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py +++ b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py @@ -50,9 +50,6 @@ def __init__(self, data_args: "DataArgs", split: str, processor: Processor): self.tokenizer.chat_template = self.DEFAULT_CHAT_TEMPLATE def dataset_template(self, sample): - if self.processor is None: - raise ValueError("TODO") - messages = sample["messages"] if messages[0]["role"] != "system": messages.insert(0, {"role": "system", "content": ""}) diff --git a/src/llmcompressor/transformers/finetune/model_args.py b/src/llmcompressor/transformers/finetune/model_args.py index 2df5a7f5d..c81900ee2 100644 --- a/src/llmcompressor/transformers/finetune/model_args.py +++ b/src/llmcompressor/transformers/finetune/model_args.py @@ -28,7 +28,6 @@ class ModelArguments: "help": "Pretrained config name or path if not the same as model_name" }, ) - # TODO: depreciate tokenizer: Optional[str] = field( default=None, metadata={ From e916936412b769ef7366c43cf8fbf39230ab8f06 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Dec 2024 00:03:53 -0500 Subject: [PATCH 143/285] Delete mllama.py --- mllama.py | 104 ------------------------------------------------------ 1 file changed, 104 deletions(-) delete mode 100644 mllama.py diff --git a/mllama.py b/mllama.py deleted file mode 100644 index 30e27e12f..000000000 --- a/mllama.py +++ /dev/null @@ -1,104 +0,0 @@ -import torch -from datasets import load_dataset -from transformers import AutoProcessor, MllamaForConditionalGeneration, LlavaForConditionalGeneration, AutoModelForCausalLM, Qwen2VLForConditionalGeneration - -from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.transformers import oneshot -import os - -# Load model. -model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" -model = MllamaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype="auto") -processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) - -print("Loading dataset") -DATASET_ID = "lmms-lab/flickr30k" -DATASET_SPLIT = "test[:512]" - -NUM_CALIBRATION_SAMPLES = 1 -MAX_SEQUENCE_LENGTH = 2048 - -# Load dataset and preprocess. -ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) -ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) - -print("Preprocessing samples") -def preprocess(example): - messages = [ - [ - { - "role": "user", - "content": [ - {"type": "image"}, - {"type": "text", "text": "What does the image show?"} - ] - } - ], - ] - return { - "text": processor.apply_chat_template( - messages, - add_generation_prompt=True, - ), - } - -ds = ds.map(preprocess, remove_columns=["caption", "sentids", "img_id", "filename"]) - - -# Tokenize inputs. -def tokenize(sample): - image = sample.pop("image") - return processor( - **sample, - images=[image], - add_special_tokens=False, - return_tensors="pt" - ) - -ds = ds.map(tokenize, remove_columns=ds.column_names) - -def data_collator(batch): - assert len(batch) == 1 - return { - "input_ids": torch.LongTensor(batch[0]["input_ids"]), - "attention_mask": torch.tensor(batch[0]["attention_mask"]), - "pixel_values": torch.tensor(batch[0]["pixel_values"]), - "aspect_ratio_ids": torch.tensor(batch[0]["aspect_ratio_ids"]), - "aspect_ratio_mask": torch.tensor(batch[0]["aspect_ratio_mask"]), - "cross_attention_mask": torch.tensor(batch[0]["cross_attention_mask"]), - } - - - -print("Setting up quantization params") -# CHANGE THIS IF YOU WANT TO QUANTIZE THE VISION TOWER -ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"] - -recipe = [ - # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), - GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore, update_size=NUM_CALIBRATION_SAMPLES), -] - -save_name = model_id.split("/")[1] + "-W8A8" -save_path = os.path.join("./my_test/", save_name) -print("Starting quantization") -oneshot( - model=model, - tokenizer=model_id, - dataset=ds, - recipe=recipe, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, - trust_remote_code_model=True, - output_dir=save_path, - data_collator=data_collator, -) - -#processor.save_pretrained(save_path) - -# Confirm generations of the quantized model look sane. -print("========== SAMPLE GENERATION ==============") -input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") -output = model.generate(input_ids, max_new_tokens=20) -print(processor.decode(output[0])) -print("==========================================") From 0a573a1263a8c92bfab1b599ac3bd77e4f6dc2d5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Dec 2024 05:11:04 +0000 Subject: [PATCH 144/285] update dataset manager api in tests Signed-off-by: Kyle Sayers --- .../transformers/compression/test_quantization.py | 4 +--- tests/llmcompressor/transformers/obcq/test_obcq_completion.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/llmcompressor/transformers/compression/test_quantization.py b/tests/llmcompressor/transformers/compression/test_quantization.py index 4a37f138d..a9debf907 100644 --- a/tests/llmcompressor/transformers/compression/test_quantization.py +++ b/tests/llmcompressor/transformers/compression/test_quantization.py @@ -135,9 +135,7 @@ def _get_dataloader(self, data_args, tokenizer): split="train_gen[:5%]", processor=tokenizer, ) - calib_dataset = dataset_manager.tokenize_and_process( - dataset_manager.get_raw_dataset() - ) + calib_dataset = dataset_manager() data_loader = DataLoader( calib_dataset, batch_size=1, diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py index 096c8df94..667b45adb 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py @@ -39,9 +39,7 @@ def labeled_dataloader(self, dataset_name, model_name): split="train", processor=tokenizer, ) - calib_dataset = dataset_manager.tokenize_and_process( - dataset_manager.get_raw_dataset() - ) + calib_dataset = dataset_manager() data_loader = DataLoader( calib_dataset, batch_size=1, collate_fn=DefaultDataCollator() ) From 853c0a81d402757010e05b9ac19716d6bde185e9 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Dec 2024 04:53:42 +0000 Subject: [PATCH 145/285] typehinting, add not-implemented error Signed-off-by: Kyle Sayers --- .../transformers/finetune/data/base.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index 0ec8e8dd6..48e5face6 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -4,6 +4,7 @@ import torch from compressed_tensors.registry import RegistryMixin from datasets import Dataset, IterableDataset +from datasets.formatting.formatting import LazyRow from loguru import logger from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments @@ -173,7 +174,7 @@ def load_dataset(self): ) @cached_property - def preprocess(self) -> Union[Callable[[Any], Any], None]: + def preprocess(self) -> Union[Callable[[LazyRow], Any], None]: """ The function must return keys which correspond to tokenizer kwargs, optionally including PROMPT_KEY @@ -209,7 +210,7 @@ def rename_columns(self, dataset: DatasetType) -> DatasetType: return dataset - def tokenize(self, data: Dict[str, Any]) -> Dict[str, Any]: + def tokenize(self, data: LazyRow) -> Dict[str, Any]: # separate prompt prompt = data.pop(self.PROMPT_KEY, None) @@ -231,7 +232,7 @@ def tokenize(self, data: Dict[str, Any]) -> Dict[str, Any]: return data - def group_text(self, data: Dict[str, Any]) -> Dict[str, Any]: + def group_text(self, data: LazyRow) -> Dict[str, Any]: concatenated_data = {k: sum(data[k], []) for k in data.keys()} total_length = len(concatenated_data[list(data.keys())[0]]) total_length = (total_length // self.max_seq_length) * self.max_seq_length @@ -244,7 +245,12 @@ def group_text(self, data: Dict[str, Any]) -> Dict[str, Any]: } return result - def add_labels(self, data): + def add_labels(self, data: LazyRow) -> LazyRow: + if "pixel_values" in data: + raise NotImplementedError( + "Label masking for vision datasets has not been implemented yet" + ) + # if the dataset uses prompts, mask them out so they don't contribute # to the loss calculation prompt_len = 0 From 234ef79efc34eaa6bd91bb9e8ac054341d185f4a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Dec 2024 04:55:27 +0000 Subject: [PATCH 146/285] remove todos Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/finetune/data/flickr_30k.py | 3 --- src/llmcompressor/transformers/finetune/data/ultrachat_200k.py | 3 --- src/llmcompressor/transformers/finetune/model_args.py | 1 - 3 files changed, 7 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/flickr_30k.py b/src/llmcompressor/transformers/finetune/data/flickr_30k.py index 03fd0fa7b..2c55bf42d 100644 --- a/src/llmcompressor/transformers/finetune/data/flickr_30k.py +++ b/src/llmcompressor/transformers/finetune/data/flickr_30k.py @@ -44,9 +44,6 @@ def __init__(self, data_args: "DataArgs", split: str, processor: Processor): self.tokenizer.chat_template = self.DEFAULT_CHAT_TEMPLATE def dataset_template(self, sample): - if self.processor is None: - raise ValueError("TODO") - messages = [ { "role": "user", diff --git a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py index b11272de2..47af48cdd 100644 --- a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py +++ b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py @@ -50,9 +50,6 @@ def __init__(self, data_args: "DataArgs", split: str, processor: Processor): self.tokenizer.chat_template = self.DEFAULT_CHAT_TEMPLATE def dataset_template(self, sample): - if self.processor is None: - raise ValueError("TODO") - messages = sample["messages"] if messages[0]["role"] != "system": messages.insert(0, {"role": "system", "content": ""}) diff --git a/src/llmcompressor/transformers/finetune/model_args.py b/src/llmcompressor/transformers/finetune/model_args.py index 2df5a7f5d..c81900ee2 100644 --- a/src/llmcompressor/transformers/finetune/model_args.py +++ b/src/llmcompressor/transformers/finetune/model_args.py @@ -28,7 +28,6 @@ class ModelArguments: "help": "Pretrained config name or path if not the same as model_name" }, ) - # TODO: depreciate tokenizer: Optional[str] = field( default=None, metadata={ From 8972dd5226f3ff1ad602fdc6473a099cbeba8a4a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Dec 2024 05:11:04 +0000 Subject: [PATCH 147/285] update dataset manager api in tests Signed-off-by: Kyle Sayers --- .../transformers/compression/test_quantization.py | 4 +--- tests/llmcompressor/transformers/obcq/test_obcq_completion.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/llmcompressor/transformers/compression/test_quantization.py b/tests/llmcompressor/transformers/compression/test_quantization.py index 4a37f138d..a9debf907 100644 --- a/tests/llmcompressor/transformers/compression/test_quantization.py +++ b/tests/llmcompressor/transformers/compression/test_quantization.py @@ -135,9 +135,7 @@ def _get_dataloader(self, data_args, tokenizer): split="train_gen[:5%]", processor=tokenizer, ) - calib_dataset = dataset_manager.tokenize_and_process( - dataset_manager.get_raw_dataset() - ) + calib_dataset = dataset_manager() data_loader = DataLoader( calib_dataset, batch_size=1, diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py index 096c8df94..667b45adb 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py @@ -39,9 +39,7 @@ def labeled_dataloader(self, dataset_name, model_name): split="train", processor=tokenizer, ) - calib_dataset = dataset_manager.tokenize_and_process( - dataset_manager.get_raw_dataset() - ) + calib_dataset = dataset_manager() data_loader = DataLoader( calib_dataset, batch_size=1, collate_fn=DefaultDataCollator() ) From acb1a184b9aa577cfb0c0a7437d114730711f16c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Dec 2024 00:21:08 -0500 Subject: [PATCH 148/285] Delete examples/multimodal_vision/qwen_vl2.py --- examples/multimodal_vision/qwen_vl2.py | 89 -------------------------- 1 file changed, 89 deletions(-) delete mode 100644 examples/multimodal_vision/qwen_vl2.py diff --git a/examples/multimodal_vision/qwen_vl2.py b/examples/multimodal_vision/qwen_vl2.py deleted file mode 100644 index 2de17d4ae..000000000 --- a/examples/multimodal_vision/qwen_vl2.py +++ /dev/null @@ -1,89 +0,0 @@ -import os - -import torch -from compressed_tensors.quantization import ( - QuantizationArgs, - QuantizationScheme, - QuantizationStrategy, - QuantizationType, -) -from transformers import AutoProcessor, Qwen2VLForConditionalGeneration - -from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.transformers import oneshot - -# Load model. -model_id = "Qwen/Qwen2-VL-2B-Instruct" -model = Qwen2VLForConditionalGeneration.from_pretrained( - model_id, device_map="auto", torch_dtype="auto" -) -processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) - -# Oneshot arguments -DATASET_ID = "flickr30k" -DATASET_SPLIT = "test[:3]" -NUM_CALIBRATION_SAMPLES = 1 -MAX_SEQUENCE_LENGTH = 2048 - - -# TODO: define real collators in utils -def data_collator(batch): - assert len(batch) == 1 - return { - "input_ids": torch.LongTensor(batch[0]["input_ids"]), - "attention_mask": torch.tensor(batch[0]["attention_mask"]), - "pixel_values": torch.tensor( - batch[0]["pixel_values"] - ), # torch.Size([14308, 1176]) - "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]), - } - - -# Recipe -recipe = GPTQModifier( - targets="Linear", - config_groups={ - "config_group": QuantizationScheme( - targets=["Linear"], - weights=QuantizationArgs( - num_bits=4, - type=QuantizationType.INT, - strategy=QuantizationStrategy.GROUP, - group_size=128, - symmetric=True, - dynamic=False, - actorder="dynamic", - ), - ), - }, - ignore=["re:.*lm_head"], - update_size=NUM_CALIBRATION_SAMPLES, - dampening_frac=0.5, -) - -# Perform oneshot -save_name = model_id.split("/")[1] + "-W8A8" -save_path = os.path.join("./my_test/", save_name) -print("Starting quantization") -oneshot( - model=model, - tokenizer=model_id, - # dataset=ds, - dataset=DATASET_ID, - splits=DATASET_SPLIT, - recipe=recipe, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, - trust_remote_code_model=True, - output_dir=save_path, - data_collator=data_collator, -) - -processor.save_pretrained(save_path) - -# Confirm generations of the quantized model look sane. -print("========== SAMPLE GENERATION ==============") -input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") -output = model.generate(input_ids, max_new_tokens=20) -print(processor.decode(output[0])) -print("==========================================") From 56b5d125e28caf9424658141999cd7a5cb5a8ac3 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Dec 2024 00:21:17 -0500 Subject: [PATCH 149/285] Delete examples/multimodal_vision/mllama.py --- examples/multimodal_vision/mllama.py | 71 ---------------------------- 1 file changed, 71 deletions(-) delete mode 100644 examples/multimodal_vision/mllama.py diff --git a/examples/multimodal_vision/mllama.py b/examples/multimodal_vision/mllama.py deleted file mode 100644 index 46f9acabe..000000000 --- a/examples/multimodal_vision/mllama.py +++ /dev/null @@ -1,71 +0,0 @@ -import os - -import torch -from transformers import AutoProcessor, MllamaForConditionalGeneration - -from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.transformers import oneshot - -# Load model. -model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" -model = MllamaForConditionalGeneration.from_pretrained( - model_id, device_map="auto", torch_dtype="auto" -) -processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) - -# Oneshot arguments -DATASET_ID = "flickr30k" -DATASET_SPLIT = "test[:512]" -NUM_CALIBRATION_SAMPLES = 1 -MAX_SEQUENCE_LENGTH = 2048 - - -# TODO: define real collators in utils -def data_collator(batch): - assert len(batch) == 1 - return { - "input_ids": torch.LongTensor(batch[0]["input_ids"]), - "attention_mask": torch.tensor(batch[0]["attention_mask"]), - "pixel_values": torch.tensor(batch[0]["pixel_values"]), - "aspect_ratio_ids": torch.tensor(batch[0]["aspect_ratio_ids"]), - "aspect_ratio_mask": torch.tensor(batch[0]["aspect_ratio_mask"]), - "cross_attention_mask": torch.tensor(batch[0]["cross_attention_mask"]), - } - - -# Recipe -recipe = [ - # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), - GPTQModifier( - targets="Linear", - scheme="W8A8", - ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"], - update_size=NUM_CALIBRATION_SAMPLES, - ), -] - -# Perform oneshot -save_name = model_id.split("/")[1] + "-W8A8" -save_path = os.path.join("./my_test/", save_name) -print("Starting quantization") -oneshot( - model=model, - tokenizer=model_id, - dataset=DATASET_ID, - splits=DATASET_SPLIT, - recipe=recipe, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, - trust_remote_code_model=True, - output_dir=save_path, - data_collator=data_collator, -) - -processor.save_pretrained(save_path) - -# Confirm generations of the quantized model look sane. -print("========== SAMPLE GENERATION ==============") -input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") -output = model.generate(input_ids, max_new_tokens=20) -print(processor.decode(output[0])) -print("==========================================") From 57c293e50f80490cae2f61675b19b118d2c61e4a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Dec 2024 19:54:01 +0000 Subject: [PATCH 150/285] WIP: add pixtral --- examples/multimodal_vision/pixtral.py | 72 +++++++++ src/llmcompressor/pytorch/tracing/llava.py | 179 +++++++++++++++++++++ 2 files changed, 251 insertions(+) create mode 100644 examples/multimodal_vision/pixtral.py create mode 100644 src/llmcompressor/pytorch/tracing/llava.py diff --git a/examples/multimodal_vision/pixtral.py b/examples/multimodal_vision/pixtral.py new file mode 100644 index 000000000..8568da91d --- /dev/null +++ b/examples/multimodal_vision/pixtral.py @@ -0,0 +1,72 @@ +import os + +import torch +from transformers import AutoProcessor +from llmcompressor.pytorch.tracing.llava import TracableLlavaForConditionalGeneration + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import oneshot + +# Load model. +model_id = "mgoin/pixtral-12b" +model = TracableLlavaForConditionalGeneration.from_pretrained( + model_id, device_map="auto", torch_dtype="auto" +) +processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + +# Oneshot arguments +DATASET_ID = "flickr30k" +DATASET_SPLIT = "test[:512]" +NUM_CALIBRATION_SAMPLES = 1 +MAX_SEQUENCE_LENGTH = 2048 + + +# TODO: define real collators in utils +def data_collator(batch): + assert len(batch) == 1 + return { + "input_ids": torch.LongTensor(batch[0]["input_ids"]), + "attention_mask": torch.tensor(batch[0]["attention_mask"]), + "pixel_values": torch.tensor(batch[0]["pixel_values"]), + "aspect_ratio_ids": torch.tensor(batch[0]["aspect_ratio_ids"]), + "aspect_ratio_mask": torch.tensor(batch[0]["aspect_ratio_mask"]), + "cross_attention_mask": torch.tensor(batch[0]["cross_attention_mask"]), + } + + +# Recipe +recipe = [ + # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), + GPTQModifier( + targets="Linear", + scheme="W8A8", + ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"], + update_size=NUM_CALIBRATION_SAMPLES, + ), +] + +# Perform oneshot +save_name = model_id.split("/")[1] + "-W8A8" +save_path = os.path.join("./my_test/", save_name) +print("Starting quantization") +oneshot( + model=model, + tokenizer=model_id, + dataset=DATASET_ID, + splits=DATASET_SPLIT, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, + output_dir=save_path, + data_collator=data_collator, +) + +processor.save_pretrained(save_path) + +# Confirm generations of the quantized model look sane. +print("========== SAMPLE GENERATION ==============") +input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=20) +print(processor.decode(output[0])) +print("==========================================") diff --git a/src/llmcompressor/pytorch/tracing/llava.py b/src/llmcompressor/pytorch/tracing/llava.py new file mode 100644 index 000000000..e4d95c786 --- /dev/null +++ b/src/llmcompressor/pytorch/tracing/llava.py @@ -0,0 +1,179 @@ +from typing import Optional, List, Union, Tuple +from functools import wraps + +import torch +from transformers import LlavaForConditionalGeneration +from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast, logger + + +class TracableLlavaForConditionalGeneration(LlavaForConditionalGeneration): + @wraps(LlavaForConditionalGeneration.forward) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[int] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, LlavaCausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + legacy_processing = False + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing + # not very reliable, but we don't expect one to actually pass 500+ images for one prompt + # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True + + # NOT TRACABLE, instead always use legacy_processing = False + # legacy_processing = ( + # (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length + # ) or (input_ids.shape[-1] == 1 and pixel_values is not None) + + image_features = None + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + if legacy_processing: + logger.warning_once( + "Expanding inputs for image tokens in LLaVa should be done in processing. " + "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " + "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." + ) + # prefill stage vs decoding stage (legacy behavior copied) + if input_ids.shape[1] != 1: + inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, attention_mask, labels + ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) + else: + # Retrieve the first layer to inspect the logits and mask out the hidden states + # that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + + # Get the target length + target_length = input_ids.shape[1] + past_length = first_layer_past_key_value.shape[-1] + + extended_attention_mask = torch.ones( + (attention_mask.shape[0], past_length), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + + attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] + + # TODO: @raushan retain only the new behavior after v4.47 + elif image_features is not None: + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] * image_features.shape[1] + + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + special_image_mask = ( + (input_ids == self.config.image_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + ) + + logits = outputs[0] + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) + shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return LlavaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) \ No newline at end of file From 537c5abd4f16241b8e89c5ae5470944301ab7554 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Dec 2024 15:11:20 -0500 Subject: [PATCH 151/285] pixtral working --- examples/multimodal_vision/pixtral.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/multimodal_vision/pixtral.py b/examples/multimodal_vision/pixtral.py index 8568da91d..aa17c08b8 100644 --- a/examples/multimodal_vision/pixtral.py +++ b/examples/multimodal_vision/pixtral.py @@ -28,9 +28,6 @@ def data_collator(batch): "input_ids": torch.LongTensor(batch[0]["input_ids"]), "attention_mask": torch.tensor(batch[0]["attention_mask"]), "pixel_values": torch.tensor(batch[0]["pixel_values"]), - "aspect_ratio_ids": torch.tensor(batch[0]["aspect_ratio_ids"]), - "aspect_ratio_mask": torch.tensor(batch[0]["aspect_ratio_mask"]), - "cross_attention_mask": torch.tensor(batch[0]["cross_attention_mask"]), } From 15b3508a2b83153fc7b579c37f44a4adb9465cf2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 03:04:20 +0000 Subject: [PATCH 152/285] move to data pipeline --- examples/multimodal_vision/pixtral.py | 2 +- pyproject.toml | 2 +- .../modifiers/quantization/gptq/base.py | 53 ++-- .../quantization/gptq/utils/gptq_quantize.py | 114 ++++--- .../gptq/utils/partitioned_model.py | 291 ------------------ .../modifiers/utils/pytorch_helpers.py | 4 +- .../pipelines/piecewise/__init__.py | 3 + .../pipelines/piecewise/helpers.py | 243 +++++++++++++++ .../{peicewise.py => piecewise/pipeline.py} | 64 ++-- src/llmcompressor/pytorch/__init__.py | 68 ++-- src/llmcompressor/pytorch/tracing/llava.py | 3 +- .../transformers/finetune/data/base.py | 1 - src/llmcompressor/utils/fsdp/helpers.py | 13 +- 13 files changed, 412 insertions(+), 449 deletions(-) delete mode 100644 src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py create mode 100644 src/llmcompressor/pipelines/piecewise/__init__.py create mode 100644 src/llmcompressor/pipelines/piecewise/helpers.py rename src/llmcompressor/pipelines/{peicewise.py => piecewise/pipeline.py} (55%) diff --git a/examples/multimodal_vision/pixtral.py b/examples/multimodal_vision/pixtral.py index aa17c08b8..14187d7d1 100644 --- a/examples/multimodal_vision/pixtral.py +++ b/examples/multimodal_vision/pixtral.py @@ -2,9 +2,9 @@ import torch from transformers import AutoProcessor -from llmcompressor.pytorch.tracing.llava import TracableLlavaForConditionalGeneration from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.pytorch.tracing.llava import TracableLlavaForConditionalGeneration from llmcompressor.transformers import oneshot # Load model. diff --git a/pyproject.toml b/pyproject.toml index 1baa7d2c0..98661216d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ profile = "black" files = "src/guidellm" [tool.ruff] -exclude = ["build", "dist", "env", ".venv"] +exclude = ["build", "dist", "env", ".venv", "src/llmcompressor/pytorch/tracing"] lint.select = ["E", "F", "W"] [tool.flake8] diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index defbb5ee3..77cb56a7b 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -16,23 +16,13 @@ make_empty_hessian, quantize_weight, ) -from llmcompressor.modifiers.quantization.gptq.utils.partitioned_model import ( - PartitionedModel, -) from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier from llmcompressor.modifiers.utils.hooks import HooksMixin -from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException -from llmcompressor.transformers.finetune.data.data_helpers import ( - create_batch_dataloader, -) +from llmcompressor.pipelines.piecewise import run_pipeline from llmcompressor.utils.fsdp.helpers import update_offload_parameter -from llmcompressor.utils.helpers import ( - align_module, - calibration_forward_context, - getattr_chain, -) +from llmcompressor.utils.helpers import align_module, getattr_chain from llmcompressor.utils.metric_logging import CompressionLogger -from llmcompressor.utils.pytorch.module import get_no_split_params, qat_active +from llmcompressor.utils.pytorch.module import qat_active __all__ = ["GPTQModifier"] @@ -106,7 +96,7 @@ class GPTQModifier(Modifier, HooksMixin): # gptq modifier arguments sequential_update: bool = True # DEPRECIATED - update_size: int = 1 + update_size: int = 512 sequential_targets: Union[str, List[str], None] = None block_size: int = 128 dampening_frac: Optional[float] = 0.01 @@ -204,21 +194,13 @@ def on_initialize(self, state: "State", **kwargs) -> bool: if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") - targets = get_no_split_params(state.model) - partitioned_model = PartitionedModel() - partitioned_model.init_forward( - state.model, targets, next(iter(state.data.calib)) - ) - # register hooks for name, module in state.model.named_modules(): if getattr_chain(module, "quantization_scheme.weights", None) is not None: post_hook = partial(self.compress_module, name) self.register_hook(module, post_hook, "forward") - # feed data - with calibration_forward_context(state.model): - partitioned_model.forward_data(state.data.calib, mask_padding=True) + run_pipeline(state.model, None, state.data.calib, propagate_error=True) return True @@ -258,11 +240,12 @@ def compress_module( # Initialize hessian if not present if module not in self._num_samples: - self._hessians[module] = make_empty_hessian(module) + init_device = "cpu" if self.offload_hessians else module.weight.device + self._hessians[module] = make_empty_hessian(module, device=init_device) self._num_samples[module] = 0 # Accumulate hessian with input with optional offloading - with self._maybe_offload_hessians(module): + with self._maybe_onload_hessians(module): self._hessians[module], self._num_samples[module] = accumulate_hessian( inp, type(module), @@ -273,11 +256,15 @@ def compress_module( # After enough samples are accumulated, perform quantization if self._num_samples[module] >= self.update_size: logger.info(f"Quantizing {name} using {self._num_samples[module]} samples") - with align_module(module), CompressionLogger(module) as comp_logger: + with ( + align_module(module), + self._maybe_onload_hessians(module), + CompressionLogger(module) as comp_logger, + ): loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( - module.weight.data, - inp, - quant_args, + weight=module.weight.data, + quant_args=quant_args, + hessian=self._hessians[module], blocksize=self.block_size, percdamp=self.dampening_frac, module_class=type(module), @@ -296,15 +283,15 @@ def compress_module( comp_logger.set_loss(loss) @contextlib.contextmanager - def _maybe_offload_hessians(self, module: torch.nn.Module): + def _maybe_onload_hessians(self, module: torch.nn.Module): if self.offload_hessians: - device = self._hessians[module].device - self._hessians[module] = self._hessians[module].to(device="cpu") + device = module.weight.device + self._hessians[module] = self._hessians[module].to(device=device) yield if self.offload_hessians: - self._hessians[module] = self._hessians[module].to(device=device) + self._hessians[module] = self._hessians[module].to(device="cpu") def _build_quant_modifier(self): """ diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index f01544330..7620dcb9a 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -17,13 +17,16 @@ GPTQ_PRECISION = torch.float32 +__all__ = ["make_empty_hessian", "accumulate_hessian", "quantize_weight"] -def make_empty_hessian(module: torch.nn.Module): + +def make_empty_hessian( + module: torch.nn.Module, device: Optional[torch.device] = None +) -> torch.Tensor: weight = module.weight num_columns = weight.shape[1] - return torch.zeros( - (num_columns, num_columns), device=weight.device, dtype=GPTQ_PRECISION - ) + device = device if device is not None else weight.device + return torch.zeros((num_columns, num_columns), device=device, dtype=GPTQ_PRECISION) def accumulate_hessian( @@ -54,54 +57,11 @@ def accumulate_hessian( return H, num_samples -def compute_hessian( - inp: torch.Tensor, module_class: Type[torch.nn.Module], device -) -> torch.Tensor: - """ - Calculate the hessian with respect to the module inputs - - :param inp: module inputs - :param module_class: class of module, likely torch.nn.Linear - :return: hessian w.r.t. module inputs - """ - inp = inp.to(device=device) - if len(inp.shape) == 2: - inp = inp.unsqueeze(0) - - nsamples = inp.shape[0] # note this is the number of dataset samples, not - # multiplied by the sequence length - - if module_class in (torch.nn.Linear, transformers.Conv1D): - if len(inp.shape) == 3: - inp = inp.reshape((-1, inp.shape[-1])) - inp = inp.t() - - inp = inp.to(dtype=GPTQ_PRECISION) - inp = math.sqrt(2 / nsamples) * inp - return inp.matmul(inp.t()) - - -def invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: - """ - Performs in-place inversion of the hessian in order to save memory - - :param H: hessian being inverted - :param percdamp: dampening factor on hessian diagonal - :return: inverted hessian - """ - damp = percdamp * torch.mean(torch.diag(H)) - diag = torch.arange(H.shape[0], device=H.device) - H[diag, diag] += damp - H = torch.linalg.cholesky(H) - H = torch.cholesky_inverse(H) - H = torch.linalg.cholesky(H, upper=True) - return H - - def quantize_weight( weight: torch.Tensor, - inp: torch.Tensor, quant_args: QuantizationArgs, + hessian: Optional[torch.Tensor] = None, + inp: Optional[torch.Tensor] = None, blocksize: int = 128, percdamp: float = 0.01, module_class: Type[torch.nn.Module] = torch.nn.Linear, @@ -144,7 +104,15 @@ def quantize_weight( num_rows = W.shape[0] num_columns = W.shape[1] - H = compute_hessian(inp, module_class, device=weight.device) + # compute hessian + if inp is not None: + if hessian is not None: + raise ValueError("Must pass either inp or hessian, but not both") + H = _compute_hessian(inp, module_class, device=weight.device) + elif hessian is not None: + H = hessian + else: + raise ValueError("Must pass either inp or hessian") if strategy == QuantizationStrategy.GROUP: # mapping from column index to group index @@ -191,7 +159,7 @@ def quantize_weight( # compute inverse hessian in place to save memory # TODO: check in place - Hinv = invert_hessian(H, percdamp) + Hinv = _invert_hessian(H, percdamp) # See section 3.4 of https://arxiv.org/abs/2203.07259 for i1 in range(0, num_columns, blocksize): @@ -302,6 +270,50 @@ def quantize_weight( ) +def _invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: + """ + Performs in-place inversion of the hessian in order to save memory + + :param H: hessian being inverted + :param percdamp: dampening factor on hessian diagonal + :return: inverted hessian + """ + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(H.shape[0], device=H.device) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + return H + + +def _compute_hessian( + inp: torch.Tensor, module_class: Type[torch.nn.Module], device +) -> torch.Tensor: + """ + Calculate the hessian with respect to the module inputs + + :param inp: module inputs + :param module_class: class of module, likely torch.nn.Linear + :return: hessian w.r.t. module inputs + """ + inp = inp.to(device=device) + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + + nsamples = inp.shape[0] # note this is the number of dataset samples, not + # multiplied by the sequence length + + if module_class in (torch.nn.Linear, transformers.Conv1D): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + + inp = inp.to(dtype=GPTQ_PRECISION) + inp = math.sqrt(2 / nsamples) * inp + return inp.matmul(inp.t()) + + def _apply_activation_ordering( W: torch.Tensor, H: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py b/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py deleted file mode 100644 index e1f796a2b..000000000 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/partitioned_model.py +++ /dev/null @@ -1,291 +0,0 @@ -import contextlib -import inspect -from collections import deque -from typing import Any, Callable, Dict, List, Set, Tuple - -import torch -import tqdm -from accelerate.hooks import remove_hook_from_module -from torch.fx import Graph, GraphModule, Node -from transformers.utils.fx import HFTracer, symbolic_trace - -from llmcompressor.modifiers.utils.hooks import HooksMixin -from llmcompressor.modifiers.utils.pytorch_helpers import ( - EarlyStopException, - apply_pad_mask_to_batch, -) -from llmcompressor.pytorch.utils.helpers import tensors_to_device -from llmcompressor.utils.helpers import calibration_forward_context - - -def get_target_nodes(graph: GraphModule, targets: List[str]): - target_nodes = [] - for node in graph.graph.nodes: - if ( - node.op == "call_module" - and type(graph.get_submodule(node.target)).__name__ in targets - ): - target_nodes.append(node) - - return target_nodes - - -def check_assumption(graph: Graph) -> bool: - for node in graph.nodes: - for user in node.users: - if node not in user.all_input_nodes: - return False - - for input_node in node.all_input_nodes: - if node not in input_node.users: - return False - - if len(node.users) != len(set(node.users)) or len(node.all_input_nodes) != len( - set(node.all_input_nodes) - ): - return False - - return True - - -def topological_partition( - graph: GraphModule, target_nodes: Set[Node] -) -> List[List[Node]]: - # use list representation to maintain topological sorting - assert check_assumption(graph.graph) - - partitions: List[List[Node]] = [[]] - remaining_indegrees = { - node: len([node for node in node.all_input_nodes if node.op != "get_attr"]) - for node in graph.graph.nodes - } - partition_index = 0 # global counter - - # start with graph input nodes - queue = deque( - node - for node in graph.graph.nodes - if remaining_indegrees[node] == 0 and node.op != "get_attr" - ) - while len(queue) > 0: - node = queue.popleft() - - # guarantee targets are assigned to disjoint partitions - if node in target_nodes: - partition_index += 1 - partitions.append([]) - - # assign to partition - partitions[partition_index].append(node) - - # recurse on last indegree only in order to guarantee that - # the node is assigned to maximal partition - for user in node.users: - remaining_indegrees[user] -= 1 - if remaining_indegrees[user] == 0: - queue.append(user) - - # a perfect solution would involve implicitly consolodating partition indices so - # that each node is assigned to the maximum partition possible (in order to delay - # execution as long as possible), but this covers the most costly case (get_attr) - for node in graph.graph.nodes: - if node.op == "get_attr": - user_partitions = [] - for user in node.users: - for index in range(len(partitions)): - if user in partitions[index]: - user_partitions.append(index) - break - partition_index = min(user_partitions) - partitions[partition_index].insert(0, node) - - assert set().union(*partitions) == set(graph.graph.nodes) - return partitions - - -def partition_graph(model: torch.nn.Module, partitions: List[List[Node]]): - subgraphs = [] - - # create subgraphs - for partition_nodes in partitions: - # create a new graph for the partition - subgraph = Graph(model) - node_map = {} - - # add placeholders for inputs not in this subgraph. use set to deduplicate - new_input_nodes = { - input_node - for node in partition_nodes - # if node.op != "get_attr" - for input_node in node.all_input_nodes - if input_node not in partition_nodes and input_node.op - } - for input_node in new_input_nodes: - node_map[input_node] = subgraph.placeholder(input_node.name) - - # add the nodes to subgraph - for node in partition_nodes: - node_map[node] = subgraph.node_copy(node, lambda n: node_map[n]) - - # add an output node to collect all subgraph outputs into a dictionary - if len(subgraph.find_nodes(op="output")) <= 0: - output_dict = { - node.name: node_map[node] - for node in partition_nodes - if any(user not in partition_nodes for user in node.users.keys()) - } - subgraph.output(output_dict) - - # Save the subgraph for this partition - subgraph.lint() - input_names = [node.name for node in subgraph.nodes if node.op == "placeholder"] - subgraphs.append( - { - "graph": subgraph, - "code": subgraph.python_code("self"), - "input_names": input_names, - "consumed_names": [], - } - ) - - print([n for n in subgraph.nodes]) - assert check_assumption(subgraph) - - return subgraphs - - -def trace_consumed_names(subgraphs: List[Dict[str, Any]]): - # TODO: update consumed names as new partitions are appended - # populate consumed_names according to when inputs are last used - # in order to vacate the `intermediates` cache and save memory - all_input_names = set().union(*(subgraph["input_names"] for subgraph in subgraphs)) - for input_name in all_input_names: - for subgraph in reversed(subgraphs): - if input_name in subgraph["input_names"]: - subgraph["consumed_names"].append(input_name) - break - else: - assert False - - -class PartitionedModel: - def __init__(self): - self.graph = None - self.subgraphs = [] - self.model = None - - def init_forward( - self, model: torch.nn.Module, targets: List[str], dummy_input: Dict[str, Any] - ): - self.model = model - self.targets = targets - - # 1. trace graph - targets = self.targets - - class CustomTracer(HFTracer): - def is_leaf_module( - self, module: torch.nn.Module, module_qualified_name: str - ) -> bool: - if type(module).__name__ in targets: - return True # Treat as leaf, skip tracing inside this module - return super().is_leaf_module(module, module_qualified_name) - - with HooksMixin.disable_hooks(), calibration_forward_context(self.model): - sig = inspect.signature(self.model.forward) - concrete_args = {} - for parameter in sig.parameters.values(): - if parameter.name in model.dummy_inputs: - continue - if parameter.kind == inspect._ParameterKind.VAR_POSITIONAL: - value = list() - elif parameter.kind == inspect._ParameterKind.VAR_KEYWORD: - value = dict() - elif parameter.name == "use_cache": - value = False - else: - value = parameter.default - - concrete_args[parameter.name] = value - - # self.graph: GraphModule = symbolic_trace(model, disable_check=True, tracer_cls=CustomTracer) - self.graph: GraphModule = torch.fx.GraphModule( - model, - CustomTracer().trace( - model, - dummy_inputs=model.dummy_inputs, - concrete_args=concrete_args, - complete_concrete_args_with_inputs_not_in_dummy_inputs=False, - ), - ) - self.graph.config = model.config - self.graph.class_for_deserialization = model.__class__ - self.graph.device = model.device - self.graph: GraphModule - - # 2. identify target nodes - all_target_nodes = get_target_nodes(self.graph, self.targets) - - # 3. cut into partitions along target nodes - partitions: List[List[Node]] = topological_partition( - self.graph, all_target_nodes - ) - self.subgraphs: List[GraphModule] = partition_graph(self.model, partitions) - - trace_consumed_names(self.subgraphs) - - def forward_data( - self, dataloader, mask_padding: bool = True, run_twice: bool = False - ): - # TODO: give option to skip lm_head - # 4. perform compression - model_device = next(self.model.parameters()).device - batch_intermediates = [ - apply_pad_mask_to_batch(batch) if mask_padding else batch - for batch in dataloader - ] - batch_outputs = [None for _ in range(len(dataloader))] - - for subgraph_index, subgraph in enumerate(self.subgraphs): - code = subgraph["code"] - exec(code.src, code.globals) - forward_function = code.globals.get("forward") - - if run_twice: - for batch_index in range(len(dataloader)): - intermediates = batch_intermediates[batch_index] - inputs = { - input_name: intermediates[input_name] - for input_name in subgraph["input_names"] - } - inputs = tensors_to_device(inputs, model_device) - try: - forward_function(self.model, **inputs) - except EarlyStopException: - pass - - with HooksMixin.disable_hooks() if run_twice else contextlib.nullcontext(): - for batch_index in range(len(dataloader)): - intermediates = batch_intermediates[batch_index] - - inputs = { - input_name: intermediates[input_name] - for input_name in subgraph["input_names"] - } - inputs = tensors_to_device(inputs, model_device) - try: - subgraph_output = forward_function(self.model, **inputs) - except EarlyStopException: - subgraph_output = None - pass - subgraph_output = tensors_to_device(subgraph_output, "cpu") - - for consumed_name in subgraph["consumed_names"]: - del intermediates[consumed_name] - - if subgraph_index < len(self.subgraphs) - 1: - intermediates.update(subgraph_output) - else: - batch_outputs[batch_index] = subgraph_output - - return batch_outputs diff --git a/src/llmcompressor/modifiers/utils/pytorch_helpers.py b/src/llmcompressor/modifiers/utils/pytorch_helpers.py index 9aeccd059..96769119c 100644 --- a/src/llmcompressor/modifiers/utils/pytorch_helpers.py +++ b/src/llmcompressor/modifiers/utils/pytorch_helpers.py @@ -60,7 +60,7 @@ def run_calibration_forward( :param model: PyTorch model to run :param calibration_dataloader: data to use for calibration :param num_calibration_steps: number of items in calibration_dataloader to process, - None or a negative number to process all available data + None or a negative number to process all available data :param calibration_function: option to pass a custom forward function for model :param device: option to move the model to a specific device before calibration :param mask_padding: whether to zero out padding tokens during calibration @@ -100,7 +100,7 @@ def run_calibration_forward( except EarlyStopException as e: # model was stopped early, save last calculated output and # move on to next calibration sample - # intermediates.append((e.args, e.kwargs)) + intermediates.append((e.args, e.kwargs)) pass # TODO: not ideal, figure out where we aren't freeing memory instead diff --git a/src/llmcompressor/pipelines/piecewise/__init__.py b/src/llmcompressor/pipelines/piecewise/__init__.py new file mode 100644 index 000000000..2b0a117ce --- /dev/null +++ b/src/llmcompressor/pipelines/piecewise/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa +from .helpers import * +from .pipeline import run_pipeline diff --git a/src/llmcompressor/pipelines/piecewise/helpers.py b/src/llmcompressor/pipelines/piecewise/helpers.py new file mode 100644 index 000000000..e09860a94 --- /dev/null +++ b/src/llmcompressor/pipelines/piecewise/helpers.py @@ -0,0 +1,243 @@ +import inspect +from collections import deque +from dataclasses import dataclass +from typing import Any, Dict, List, Set + +from torch.fx import Graph, GraphModule, Node +from torch.nn import Module +from transformers.utils.fx import HFTracer + +from llmcompressor.modifiers.utils.hooks import HooksMixin +from llmcompressor.recipe import Recipe +from llmcompressor.utils.helpers import calibration_forward_context +from llmcompressor.utils.pytorch.module import get_no_split_params + + +@dataclass +class Subgraph: + graph: Graph + input_names: List[str] + consumed_names: List[str] + + +__all__ = ["get_compression_targets", "trace_subgraphs"] + + +def get_compression_targets(model: Module, recipe: Recipe) -> Set[Module]: + """ + TODO: true sequential + + List of modules which are guaranteed to be split into different partitions and + whose inner operations will not be traced + """ + no_split_params = get_no_split_params(model) + return set( + module for module in model.modules() if type(module).__name__ in no_split_params + ) + + +def trace_subgraphs( + model: Module, sample_input: Dict[str, Any], targets: Set[Module] +) -> List[Subgraph]: + # initialize arguments + tracer = get_tracer(targets) + concrete_args = populate_concrete_args(model, sample_input) + + # trace + with calibration_forward_context(model), HooksMixin.disable_hooks(): + graph = GraphModule( + model, + tracer.trace( + model, + dummy_inputs=sample_input, + concrete_args=concrete_args, + complete_concrete_args_with_inputs_not_in_dummy_inputs=False, + # bug in trace throws an error for variadic + # args and kwargs in function signature + ), + ) + + # copy metadata + graph.config = model.config + graph.class_for_deserialization = model.__class__ + graph.device = model.device + + # perform subgraph partition + partitions = topological_partition(graph, targets) + subgraphs = partition_graph(model, partitions) + trace_consumed_names(subgraphs) + + return subgraphs + + +def get_tracer(targets: List[Module]) -> HFTracer: + class PiecewiseTracer(HFTracer): + def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool: + if module in targets: + return True # Treat as leaf, skip tracing inside this module + return super().is_leaf_module(module, module_qualified_name) + + return PiecewiseTracer() + + +def populate_concrete_args(model: Module, sample_input: Dict) -> Dict: + sig = inspect.signature(model.forward) + + concrete_args = {} + for parameter in sig.parameters.values(): + if parameter.name in sample_input: + continue + if parameter.kind == inspect._ParameterKind.VAR_POSITIONAL: + value = list() + elif parameter.kind == inspect._ParameterKind.VAR_KEYWORD: + value = dict() + elif parameter.name == "use_cache": + value = False + else: + value = parameter.default + + concrete_args[parameter.name] = value + + return concrete_args + + +def get_target_nodes(graph: GraphModule, targets: Set[Module]) -> Set[Node]: + return set( + node + for node in graph.graph.nodes + if node.op == "call_module" and graph.get_submodule(node.target) in targets + ) + + +def check_assumption(graph: Graph) -> bool: + for node in graph.nodes: + for user in node.users: + if node not in user.all_input_nodes: + return False + + for input_node in node.all_input_nodes: + if node not in input_node.users: + return False + + if len(node.users) != len(set(node.users)) or len(node.all_input_nodes) != len( + set(node.all_input_nodes) + ): + return False + + return True + + +def topological_partition(graph: GraphModule, targets: Set[Module]) -> List[List[Node]]: + assert check_assumption(graph.graph) + target_nodes = get_target_nodes(graph, targets) + + partitions: List[List[Node]] = [[]] + remaining_indegrees = { + node: len([node for node in node.all_input_nodes if node.op != "get_attr"]) + for node in graph.graph.nodes + } + partition_index = 0 # global counter + + # start with graph input nodes + queue = deque( + node + for node in graph.graph.nodes + if remaining_indegrees[node] == 0 and node.op != "get_attr" + ) + while len(queue) > 0: + node = queue.popleft() + + # guarantee targets are assigned to disjoint partitions + if node in target_nodes: + partition_index += 1 + partitions.append([]) + + # assign to partition + partitions[partition_index].append(node) + + # recurse on last indegree only in order to guarantee that + # the node is assigned to maximal partition + for user in node.users: + remaining_indegrees[user] -= 1 + if remaining_indegrees[user] == 0: + queue.append(user) + + # a perfect solution would involve implicitly consolodating partition indices so + # that each node is assigned to the maximum partition possible (in order to delay + # execution as long as possible), but this covers the most costly case (get_attr) + for node in graph.graph.nodes: + if node.op == "get_attr": + user_partitions = [] + for user in node.users: + for index in range(len(partitions)): + if user in partitions[index]: + user_partitions.append(index) + break + partition_index = min(user_partitions) + partitions[partition_index].insert(0, node) + + assert set().union(*partitions) == set(graph.graph.nodes) + return partitions + + +def partition_graph(model: Module, partitions: List[List[Node]]) -> List[Subgraph]: + subgraphs = [] + + # create subgraphs + for partition_nodes in partitions: + # create a new graph for the partition + graph = Graph(model) + node_map = {} + + # add placeholders for inputs not in this subgraph. use set to deduplicate + new_input_nodes = { + input_node + for node in partition_nodes + # if node.op != "get_attr" + for input_node in node.all_input_nodes + if input_node not in partition_nodes and input_node.op + } + for input_node in new_input_nodes: + node_map[input_node] = graph.placeholder(input_node.name) + + # add the nodes to subgraph + for node in partition_nodes: + node_map[node] = graph.node_copy(node, lambda n: node_map[n]) + + # add an output node to collect all subgraph outputs into a dictionary + if len(graph.find_nodes(op="output")) <= 0: + output_dict = { + node.name: node_map[node] + for node in partition_nodes + if any(user not in partition_nodes for user in node.users.keys()) + } + graph.output(output_dict) + + # Save the subgraph for this partition + graph.lint() + input_names = [node.name for node in graph.nodes if node.op == "placeholder"] + subgraphs.append( + Subgraph( + graph=graph, + input_names=input_names, + consumed_names=[], # populated later + ) + ) + + assert check_assumption(graph) + + return subgraphs + + +def trace_consumed_names(subgraphs: List[Dict[str, Any]]): + # TODO: update consumed names as new partitions are appended + # populate consumed_names according to when inputs are last used + # in order to vacate the `intermediates` cache and save memory + all_input_names = set().union(*(subgraph.input_names for subgraph in subgraphs)) + for input_name in all_input_names: + for subgraph in reversed(subgraphs): + if input_name in subgraph.input_names: + subgraph.consumed_names.append(input_name) + break + else: + assert False diff --git a/src/llmcompressor/pipelines/peicewise.py b/src/llmcompressor/pipelines/piecewise/pipeline.py similarity index 55% rename from src/llmcompressor/pipelines/peicewise.py rename to src/llmcompressor/pipelines/piecewise/pipeline.py index c4e94e26b..a2cc0ea76 100644 --- a/src/llmcompressor/pipelines/peicewise.py +++ b/src/llmcompressor/pipelines/piecewise/pipeline.py @@ -1,60 +1,59 @@ -import contextlib +from contextlib import nullcontext import torch -from datasets import Dataset +import torch.utils.data.dataloader +import tqdm -from llmcompressor.core.session_functions import initialize from llmcompressor.modifiers.utils.hooks import HooksMixin -from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException -from llmcompressor.recipe.recipe import Recipe -from llmcompressor.utils.helpers import ( - calibration_forward_context, - create_dataloader, - get_model_device, - get_targets, - tensors_to_device, +from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch +from llmcompressor.pipelines.piecewise.helpers import ( + get_compression_targets, trace_subgraphs, ) +from llmcompressor.pytorch.utils.helpers import tensors_to_device +from llmcompressor.recipe.recipe import Recipe +from llmcompressor.utils.helpers import calibration_forward_context + +__all__ = ["run_pipeline"] def run_pipeline( model: torch.nn.Module, recipe: Recipe, - dataset: Dataset, + dataloader: torch.utils.data.DataLoader, propagate_error: bool, ): # trace subgraphs - targets = get_targets(recipe) - sample_input_names = next(iter(dataset)).keys() - subgraphs = trace_subgraphs(model, sample_input_names, targets) + sample_input = next(iter(dataloader)) + targets = get_compression_targets(model, recipe) + subgraphs = trace_subgraphs(model, sample_input, targets) - # apply recipe to model - initialize(recipe, model) + # FUTURE: apply recipe to model + # initialize(recipe, model) - # create dataloader - model_device = get_model_device(model) - dataloader = create_dataloader( - dataset, batch_size=..., mask_padding=True, model_device=model_device - ) + model_device = next(model.parameters()).device with calibration_forward_context(model): # prepare intermediates cache - batch_intermediates = list(iter(dataloader)) + batch_intermediates = [ + apply_pad_mask_to_batch(batch) for batch in iter(dataloader) + ] batch_outputs = [None for _ in range(len(dataloader))] for subgraph_index, subgraph in enumerate(subgraphs): # compile subgraph forward function - code = subgraph["code"] + code = subgraph.graph.python_code("self") exec(code.src, code.globals) forward_function = code.globals.get("forward") if propagate_error: # do an preliminary pass to trigger modifier hooks - for batch_index in range(len(dataloader)): + desc = f"(Partition {subgraph_index}): Uncompressed forward" + for batch_index in tqdm.tqdm(range(len(dataloader)), desc=desc): intermediates = batch_intermediates[batch_index] inputs = { input_name: intermediates[input_name] - for input_name in subgraph["input_names"] + for input_name in subgraph.input_names } inputs = tensors_to_device(inputs, model_device) forward_function(model, **inputs) @@ -62,19 +61,24 @@ def run_pipeline( # if using propagate_error, then this pass does not trigger modifier hooks # and is only used for capturing intermediates # otherwise, this pass triggers modifier hooks and captures intermediates - with HooksMixin.disable_hooks() if propagate_error else contextlib.nullcontext(): - for batch_index in range(len(dataloader)): + with HooksMixin.disable_hooks() if propagate_error else nullcontext(): + desc = ( + f"(Partition {subgraph_index}): Compressed forward" + if propagate_error + else f"(Partition {subgraph_index}): Uncompressed forward" + ) + for batch_index in tqdm.tqdm(range(len(dataloader)), desc=desc): intermediates = batch_intermediates[batch_index] inputs = { input_name: intermediates[input_name] - for input_name in subgraph["input_names"] + for input_name in subgraph.input_names } inputs = tensors_to_device(inputs, model_device) subgraph_output = forward_function(model, **inputs) subgraph_output = tensors_to_device(subgraph_output, "cpu") - for consumed_name in subgraph["consumed_names"]: + for consumed_name in subgraph.consumed_names: del intermediates[consumed_name] if subgraph_index < len(subgraphs) - 1: diff --git a/src/llmcompressor/pytorch/__init__.py b/src/llmcompressor/pytorch/__init__.py index 869f83f04..babb9af4c 100644 --- a/src/llmcompressor/pytorch/__init__.py +++ b/src/llmcompressor/pytorch/__init__.py @@ -7,39 +7,39 @@ from packaging import version -# try: -# import torch - -# _PARSED_TORCH_VERSION = version.parse(torch.__version__) - -# if _PARSED_TORCH_VERSION.major >= 2: -# torch_compile_func = torch.compile - -# def raise_torch_compile_warning(*args, **kwargs): -# warnings.warn( -# "torch.compile is not supported by llmcompressor for torch 2.0.x" -# ) -# return torch_compile_func(*args, **kwargs) - -# torch.compile = raise_torch_compile_warning - -# _BYPASS = bool(int(os.environ.get("NM_BYPASS_TORCH_VERSION", "0"))) -# if _PARSED_TORCH_VERSION.major == 1 and _PARSED_TORCH_VERSION.minor in [10, 11]: -# if not _BYPASS: -# raise RuntimeError( -# "llmcompressor does not support torch==1.10.* or 1.11.*. " -# f"Found torch version {torch.__version__}.\n\n" -# "To bypass this error, set environment variable " -# "`NM_BYPASS_TORCH_VERSION` to '1'.\n\n" -# "Bypassing may result in errors or " -# "incorrect behavior, so set at your own risk." -# ) -# else: -# warnings.warn( -# "llmcompressor quantized onnx export does not work " -# "with torch==1.10.* or 1.11.*" -# ) -# except ImportError: -# pass +try: + import torch + + _PARSED_TORCH_VERSION = version.parse(torch.__version__) + + if _PARSED_TORCH_VERSION.major >= 2: + torch_compile_func = torch.compile + + def raise_torch_compile_warning(*args, **kwargs): + warnings.warn( + "torch.compile is not supported by llmcompressor for torch 2.0.x" + ) + return torch_compile_func(*args, **kwargs) + + torch.compile = raise_torch_compile_warning + + _BYPASS = bool(int(os.environ.get("NM_BYPASS_TORCH_VERSION", "0"))) + if _PARSED_TORCH_VERSION.major == 1 and _PARSED_TORCH_VERSION.minor in [10, 11]: + if not _BYPASS: + raise RuntimeError( + "llmcompressor does not support torch==1.10.* or 1.11.*. " + f"Found torch version {torch.__version__}.\n\n" + "To bypass this error, set environment variable " + "`NM_BYPASS_TORCH_VERSION` to '1'.\n\n" + "Bypassing may result in errors or " + "incorrect behavior, so set at your own risk." + ) + else: + warnings.warn( + "llmcompressor quantized onnx export does not work " + "with torch==1.10.* or 1.11.*" + ) +except ImportError: + pass # # flake8: noqa diff --git a/src/llmcompressor/pytorch/tracing/llava.py b/src/llmcompressor/pytorch/tracing/llava.py index e4d95c786..23f7a6bbb 100644 --- a/src/llmcompressor/pytorch/tracing/llava.py +++ b/src/llmcompressor/pytorch/tracing/llava.py @@ -1,5 +1,6 @@ -from typing import Optional, List, Union, Tuple +# flake8: noqa from functools import wraps +from typing import List, Optional, Tuple, Union import torch from transformers import LlavaForConditionalGeneration diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index 48e5face6..ea3c71e7b 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -1,7 +1,6 @@ from functools import cached_property from typing import Any, Callable, Dict, List, Optional, Union -import torch from compressed_tensors.registry import RegistryMixin from datasets import Dataset, IterableDataset from datasets.formatting.formatting import LazyRow diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py index 7106b64b1..c7ac2aa1b 100644 --- a/src/llmcompressor/utils/fsdp/helpers.py +++ b/src/llmcompressor/utils/fsdp/helpers.py @@ -222,6 +222,7 @@ def has_offloaded_params(module: torch.nn.Module) -> bool: ) +# TODO: define in compressed tensors # depreciation candidate @wraps(has_offloaded_params) def is_module_offloaded(module: torch.nn.Module) -> bool: @@ -305,7 +306,8 @@ def update_offload_parameter( if data is not None: if data.device == "meta": raise ValueError( - "Cannot copy data from meta device. Consider calling with align_module(module) context" + "Cannot copy data from meta device. Consider calling with " + "align_module(module) context" ) if param.data.dtype != data.dtype: @@ -318,7 +320,8 @@ def update_offload_parameter( if has_offloaded_params(module): weights_map = module._hf_hook.weights_map - # for upstreaming, probably better to modify the weight map types so that they can be written to? + # for upstreaming, probably better to modify the weight + # map types so that they can be written to? if isinstance(weights_map, PrefixedDataset): prefix_dict = getattr_chain( module, "module._hf_hook.weights_map.dataset", None @@ -366,7 +369,8 @@ def align_module( module's execution device within the context. Yields: - None: Yields control while the module's parameters are aligned to the execution device. + None: Yields control while the module's parameters are + aligned to the execution device. """ if has_offloaded_params(module): if execution_device is not None: @@ -437,7 +441,8 @@ def delete_offload_parameter(module: torch.nn.Module, name: str): if has_offloaded_params(module): weights_map = module._hf_hook.weights_map - # for upstreaming, probably better to modify the weight map types so that they can be written to? + # for upstreaming, probably better to modify the + # weight map types so that they can be written to? if isinstance(weights_map, PrefixedDataset): dataset = weights_map.dataset prefix = weights_map.prefix From 42b5fc07820b275025e003bf89507c009a0423af Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 03:30:23 +0000 Subject: [PATCH 153/285] disable_hf_hook context --- .../modifiers/quantization/gptq/base.py | 3 +- .../pipelines/piecewise/helpers.py | 8 +++-- src/llmcompressor/pytorch/__init__.py | 2 +- src/llmcompressor/utils/helpers.py | 29 +++++++++++++++++++ 4 files changed, 38 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 77cb56a7b..4792ab0ce 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -291,7 +291,8 @@ def _maybe_onload_hessians(self, module: torch.nn.Module): yield if self.offload_hessians: - self._hessians[module] = self._hessians[module].to(device="cpu") + if module in self._hessians: # may have been deleted in context + self._hessians[module] = self._hessians[module].to(device="cpu") def _build_quant_modifier(self): """ diff --git a/src/llmcompressor/pipelines/piecewise/helpers.py b/src/llmcompressor/pipelines/piecewise/helpers.py index e09860a94..cc54a5b3d 100644 --- a/src/llmcompressor/pipelines/piecewise/helpers.py +++ b/src/llmcompressor/pipelines/piecewise/helpers.py @@ -9,7 +9,7 @@ from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.recipe import Recipe -from llmcompressor.utils.helpers import calibration_forward_context +from llmcompressor.utils.helpers import calibration_forward_context, disable_hf_hook from llmcompressor.utils.pytorch.module import get_no_split_params @@ -44,7 +44,11 @@ def trace_subgraphs( concrete_args = populate_concrete_args(model, sample_input) # trace - with calibration_forward_context(model), HooksMixin.disable_hooks(): + with ( + calibration_forward_context(model), + HooksMixin.disable_hooks(), + disable_hf_hook(model, recurse=True), + ): graph = GraphModule( model, tracer.trace( diff --git a/src/llmcompressor/pytorch/__init__.py b/src/llmcompressor/pytorch/__init__.py index babb9af4c..66d4be1b4 100644 --- a/src/llmcompressor/pytorch/__init__.py +++ b/src/llmcompressor/pytorch/__init__.py @@ -42,4 +42,4 @@ def raise_torch_compile_warning(*args, **kwargs): except ImportError: pass -# # flake8: noqa +# flake8: noqa diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index f1597c279..e3aa644d0 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -24,6 +24,8 @@ import numpy import torch +from accelerate.hooks import add_hook_to_module, remove_hook_from_module +from accelerate.utils import PrefixedDataset from compressed_tensors import is_module_offloaded from compressed_tensors.quantization import disable_quantization, enable_quantization from loguru import logger @@ -1148,3 +1150,30 @@ def align_module(module: torch.nn.Module, device: Optional[torch.device] = None) else: yield + + +@contextlib.contextmanager +def disable_hf_hook(module: torch.nn.Module, recurse: bool = False): + hook = module._hf_hook + prefix_dict = module._hf_hook.weights_map + new_prefix = {} + + # recreate the prefix dict (since it is immutable) + # and add quantization parameters + if prefix_dict is not None: + for key, data in module.named_parameters(): + if key not in prefix_dict: + new_prefix[f"{prefix_dict.prefix}{key}"] = data + else: + new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key] + prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix) + + # removing the AlignDevicesHook also moves the module to the original device + remove_hook_from_module(module, recurse=recurse) + + yield + + # we need to re-add the hook for offloading now that we've wrapped forward + add_hook_to_module(module, hook) + if prefix_dict is not None: + module._hf_hook.weights_map = prefix_dict From bc33e8eb14c3de9bb3b44f970a33eade000d7907 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 06:45:23 +0000 Subject: [PATCH 154/285] woof --- .../pipelines/piecewise/pipeline.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/llmcompressor/pipelines/piecewise/pipeline.py b/src/llmcompressor/pipelines/piecewise/pipeline.py index a2cc0ea76..05e220078 100644 --- a/src/llmcompressor/pipelines/piecewise/pipeline.py +++ b/src/llmcompressor/pipelines/piecewise/pipeline.py @@ -40,7 +40,12 @@ def run_pipeline( ] batch_outputs = [None for _ in range(len(dataloader))] - for subgraph_index, subgraph in enumerate(subgraphs): + num_subgraphs = len(subgraphs) + for index, subgraph in enumerate(subgraphs): + # prepare tqdm description texts + unc_desc = f"(Partition {index + 1}/{num_subgraphs}): Uncompressed forward" + comp_desc = f"(Partition {index + 1}/{num_subgraphs}): Compressed forward" + # compile subgraph forward function code = subgraph.graph.python_code("self") exec(code.src, code.globals) @@ -48,8 +53,7 @@ def run_pipeline( if propagate_error: # do an preliminary pass to trigger modifier hooks - desc = f"(Partition {subgraph_index}): Uncompressed forward" - for batch_index in tqdm.tqdm(range(len(dataloader)), desc=desc): + for batch_index in tqdm.tqdm(range(len(dataloader)), desc=unc_desc): intermediates = batch_intermediates[batch_index] inputs = { input_name: intermediates[input_name] @@ -62,11 +66,7 @@ def run_pipeline( # and is only used for capturing intermediates # otherwise, this pass triggers modifier hooks and captures intermediates with HooksMixin.disable_hooks() if propagate_error else nullcontext(): - desc = ( - f"(Partition {subgraph_index}): Compressed forward" - if propagate_error - else f"(Partition {subgraph_index}): Uncompressed forward" - ) + desc = comp_desc if propagate_error else unc_desc for batch_index in tqdm.tqdm(range(len(dataloader)), desc=desc): intermediates = batch_intermediates[batch_index] @@ -81,7 +81,7 @@ def run_pipeline( for consumed_name in subgraph.consumed_names: del intermediates[consumed_name] - if subgraph_index < len(subgraphs) - 1: + if index < len(subgraphs) - 1: intermediates.update(subgraph_output) else: batch_outputs[batch_index] = subgraph_output From ca72bbb8381460be8f6dc6ddf4d250f978285afa Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 06:56:45 +0000 Subject: [PATCH 155/285] change desc --- src/llmcompressor/pipelines/piecewise/pipeline.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/pipelines/piecewise/pipeline.py b/src/llmcompressor/pipelines/piecewise/pipeline.py index 05e220078..952dafdb2 100644 --- a/src/llmcompressor/pipelines/piecewise/pipeline.py +++ b/src/llmcompressor/pipelines/piecewise/pipeline.py @@ -43,8 +43,8 @@ def run_pipeline( num_subgraphs = len(subgraphs) for index, subgraph in enumerate(subgraphs): # prepare tqdm description texts - unc_desc = f"(Partition {index + 1}/{num_subgraphs}): Uncompressed forward" - comp_desc = f"(Partition {index + 1}/{num_subgraphs}): Compressed forward" + uncomp_desc = f"({index + 1}/{num_subgraphs}): Calibrating" + comp_desc = f"({index + 1}/{num_subgraphs}): Propagate" # compile subgraph forward function code = subgraph.graph.python_code("self") @@ -53,7 +53,7 @@ def run_pipeline( if propagate_error: # do an preliminary pass to trigger modifier hooks - for batch_index in tqdm.tqdm(range(len(dataloader)), desc=unc_desc): + for batch_index in tqdm.tqdm(range(len(dataloader)), desc=uncomp_desc): intermediates = batch_intermediates[batch_index] inputs = { input_name: intermediates[input_name] @@ -66,7 +66,7 @@ def run_pipeline( # and is only used for capturing intermediates # otherwise, this pass triggers modifier hooks and captures intermediates with HooksMixin.disable_hooks() if propagate_error else nullcontext(): - desc = comp_desc if propagate_error else unc_desc + desc = comp_desc if propagate_error else uncomp_desc for batch_index in tqdm.tqdm(range(len(dataloader)), desc=desc): intermediates = batch_intermediates[batch_index] From 293640a17976bb0cc8563234af4fdec18b58d09d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 07:03:29 +0000 Subject: [PATCH 156/285] fix docstring --- .../modifiers/quantization/gptq/utils/gptq_quantize.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 7620dcb9a..a6f3e7d48 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -65,15 +65,14 @@ def quantize_weight( blocksize: int = 128, percdamp: float = 0.01, module_class: Type[torch.nn.Module] = torch.nn.Linear, - weight_original: Optional[torch.Tensor] = None, ) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]: """ Quantize a module weight according to the GPTQ algorithm - TODO :param weight: weight being quantized - :param inp: module inputs used to calculate hessian :param quant_args: quantization arguments used to find quantization parameters + :param hessian: preaccumulated hessian for quantization + :param inp: module inputs used to calculate hessian. Incompatible with `hessian` arg :param blocksize: chunk size of quantization updates :param percdamp: dampening factor on hessian diagonal :param module_class: class of module, likely torch.nn.Linear @@ -92,9 +91,6 @@ def quantize_weight( averaging_constant=1.0, # ignore moving average ) - if weight_original is not None: - raise NotImplementedError() - # standardize shape and dtype if module_class == torch.nn.Conv2d: W = W.flatten(1) From 17b3a70e267addee4f2875eb92cb7cf878f50c4d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 08:14:29 +0000 Subject: [PATCH 157/285] rely on compressed tensors, support offloading --- .../modifiers/quantization/gptq/base.py | 18 +- .../pipelines/piecewise/helpers.py | 4 +- .../pipelines/piecewise/pipeline.py | 7 +- src/llmcompressor/utils/fsdp/helpers.py | 276 ------------------ src/llmcompressor/utils/helpers.py | 68 +---- 5 files changed, 22 insertions(+), 351 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 4792ab0ce..6362e0d16 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -5,6 +5,12 @@ import torch from compressed_tensors.quantization import QuantizationScheme +from compressed_tensors.utils import ( + align_module_device, + get_execution_device, + getattr_chain, + update_offload_parameter, +) from loguru import logger from pydantic import Field, PrivateAttr, field_validator @@ -19,8 +25,6 @@ from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.pipelines.piecewise import run_pipeline -from llmcompressor.utils.fsdp.helpers import update_offload_parameter -from llmcompressor.utils.helpers import align_module, getattr_chain from llmcompressor.utils.metric_logging import CompressionLogger from llmcompressor.utils.pytorch.module import qat_active @@ -240,7 +244,9 @@ def compress_module( # Initialize hessian if not present if module not in self._num_samples: - init_device = "cpu" if self.offload_hessians else module.weight.device + init_device = ( + "cpu" if self.offload_hessians else get_execution_device(module) + ) self._hessians[module] = make_empty_hessian(module, device=init_device) self._num_samples[module] = 0 @@ -257,7 +263,7 @@ def compress_module( if self._num_samples[module] >= self.update_size: logger.info(f"Quantizing {name} using {self._num_samples[module]} samples") with ( - align_module(module), + align_module_device(module), self._maybe_onload_hessians(module), CompressionLogger(module) as comp_logger, ): @@ -271,7 +277,7 @@ def compress_module( ) module.weight += quantized_weight - module.weight # Future: FSDP - update_offload_parameter(module, "weight") + update_offload_parameter(module, "weight", module.weight.data) update_offload_parameter(module, "weight_scale", scale) update_offload_parameter(module, "weight_zero_point", zero_point) if g_idx is not None: @@ -285,7 +291,7 @@ def compress_module( @contextlib.contextmanager def _maybe_onload_hessians(self, module: torch.nn.Module): if self.offload_hessians: - device = module.weight.device + device = get_execution_device(module) self._hessians[module] = self._hessians[module].to(device=device) yield diff --git a/src/llmcompressor/pipelines/piecewise/helpers.py b/src/llmcompressor/pipelines/piecewise/helpers.py index cc54a5b3d..80166f292 100644 --- a/src/llmcompressor/pipelines/piecewise/helpers.py +++ b/src/llmcompressor/pipelines/piecewise/helpers.py @@ -3,13 +3,14 @@ from dataclasses import dataclass from typing import Any, Dict, List, Set +from compressed_tensors.utils import disable_hf_hook from torch.fx import Graph, GraphModule, Node from torch.nn import Module from transformers.utils.fx import HFTracer from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.recipe import Recipe -from llmcompressor.utils.helpers import calibration_forward_context, disable_hf_hook +from llmcompressor.utils.helpers import calibration_forward_context from llmcompressor.utils.pytorch.module import get_no_split_params @@ -151,6 +152,7 @@ def topological_partition(graph: GraphModule, targets: Set[Module]) -> List[List while len(queue) > 0: node = queue.popleft() + # TODO: test swapping with below # guarantee targets are assigned to disjoint partitions if node in target_nodes: partition_index += 1 diff --git a/src/llmcompressor/pipelines/piecewise/pipeline.py b/src/llmcompressor/pipelines/piecewise/pipeline.py index 952dafdb2..142a7ce11 100644 --- a/src/llmcompressor/pipelines/piecewise/pipeline.py +++ b/src/llmcompressor/pipelines/piecewise/pipeline.py @@ -31,7 +31,12 @@ def run_pipeline( # FUTURE: apply recipe to model # initialize(recipe, model) - model_device = next(model.parameters()).device + # TODO: revisit + device_map = getattr(model, "hf_device_map", None) + if device_map is not None: + model_device = next(iter(device_map.values())) + else: + model_device = model.device with calibration_forward_context(model): # prepare intermediates cache diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py index c7ac2aa1b..3a3248fa5 100644 --- a/src/llmcompressor/utils/fsdp/helpers.py +++ b/src/llmcompressor/utils/fsdp/helpers.py @@ -1,14 +1,9 @@ -import contextlib import operator -import warnings -from functools import wraps from pathlib import Path from typing import Optional from loguru import logger -from llmcompressor.utils.helpers import getattr_chain - try: from torch.distributed.fsdp import ( FullStateDictConfig, @@ -26,21 +21,6 @@ from llmcompressor.typing import Processor from llmcompressor.utils.pytorch import set_layer -try: - from accelerate.hooks import AlignDevicesHook - from accelerate.utils import ( - OffloadedWeightsLoader, - PrefixedDataset, - set_module_tensor_to_device, - ) - - _has_accelerate = True -except ImportError: - _has_accelerate = False - AlignDevicesHook = None - OffloadedWeightsLoader = None - PrefixedDataset = None - __all__ = [ "is_fsdp_model", "maybe_get_wrapped", @@ -200,259 +180,3 @@ def get_fsdp_parent(layer_name: str, model: Module) -> Optional[Module]: parent = operator.attrgetter(parent_name)(model) return parent - - -# upstream candidate -def has_offloaded_params(module: torch.nn.Module) -> bool: - """ - Checks if a module has offloaded parameters by checking if the given module - has a AlignDevicesHook attached with offloading enabled - - Args: - module (`torch.nn.Module`): The module to check for an offload hook. - - Returns: - bool: `True` if the module has an offload hook and offloading is enabled, - `False` otherwise. - """ - return ( - hasattr(module, "_hf_hook") - and isinstance(module._hf_hook, AlignDevicesHook) - and module._hf_hook.offload - ) - - -# TODO: define in compressed tensors -# depreciation candidate -@wraps(has_offloaded_params) -def is_module_offloaded(module: torch.nn.Module) -> bool: - if not _has_accelerate: - return False - - return has_offloaded_params(module) - - -# depreciation candidate -def get_execution_device(module: torch.nn.Module) -> torch.device: - """ - :param module: module to check - :return: device module is loaded onto during forward pass - """ - if is_module_offloaded(module): - return module._hf_hook.execution_device - device = next(module.parameters()).device - - # offload only gets set for leaf modules, fallback to checking for device type - if device.type == "meta": - return module._hf_hook.execution_device - - return device - - -# upstream candidate -def _infer_offload_device(module: torch.nn.Module) -> torch.device: - if not has_offloaded_params(module): - raise ValueError("Cannot infer offload device from non-offloaded module") - - first_key = next(module._hf_hook.weights_map.keys(), None) - if first_key is None: - raise ValueError("Cannot infer offload device from empty weights map") - - prefix_dataset = module._hf_hook.weights_map.dataset - return prefix_dataset[first_key].device - - -# depreciation candidate -def get_offloaded_device(module: torch.nn.Module) -> torch.device: - """ - :param module: module to check - :return: device module is offloaded to onto after forward pass - """ - return _infer_offload_device(module) - - -# depreciation candidate -def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor): - """ - Updates the offloaded state dict for a given module. Parameter named key is replaced - by data. This is neccesary because parameter updates for offloaded modules do not - persist automatically between loads. This function only affects the offloaded - state dict and not the current state of the loaded module. - - :param module: module containing the parameter to update - :param key: name of parameter to update - :param data: tensor to update parameter with in the offloaded state dict - """ - if not is_module_offloaded(module): - raise ValueError("Prefix dict is only applicable to offloaded modules") - prefix_dict = module._hf_hook.weights_map - prefix_dict.dataset[f"{prefix_dict.prefix}{key}"] = data - - -# upstream candidate? -def update_offload_parameter( - module: torch.nn.Module, - name: str, - data: Optional[torch.Tensor] = None, - offload_device: Optional[torch.device] = None, -): - """ - :param module: module containing the parameter to update - :param name: name of module parameter to update - :param data: tensor to update parameter with - :param offload_device: offload device for newly registered parameters - """ - param = getattr(module, name) - if data is not None: - if data.device == "meta": - raise ValueError( - "Cannot copy data from meta device. Consider calling with " - "align_module(module) context" - ) - - if param.data.dtype != data.dtype: - print(name) - print((param.data.dtype, data.dtype)) - warnings.warn("TODO") - - param.data.copy_(data) - - if has_offloaded_params(module): - weights_map = module._hf_hook.weights_map - - # for upstreaming, probably better to modify the weight - # map types so that they can be written to? - if isinstance(weights_map, PrefixedDataset): - prefix_dict = getattr_chain( - module, "module._hf_hook.weights_map.dataset", None - ) - if prefix_dict is not None: - prefix = module._hf_hook.weights_map.prefix - key = f"{prefix}{name}" - - offload_device = ( - prefix_dict[key].device - if key in prefix_dict - else offload_device - if offload_device is not None - else _infer_offload_device(module) - ) - prefix_dict[key] = param.data.to(device=offload_device) - - if isinstance(weights_map, OffloadedWeightsLoader): - raise NotImplementedError() - - else: - raise NotImplementedError() - - -# depreciation candidate -def update_parameter_data( - module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str -): - param = getattr(module, param_name) - new_param_data = new_param_data.to(device=param.device, dtype=param.dtype) - update_offload_parameter(module, param_name, new_param_data) - - -# upstream candidate -@contextlib.contextmanager -def align_module( - module: torch.nn.Module, execution_device: Optional[torch.device] = None -): - """ - Moves a module's parameters to the specified execution device. - - Args: - module (torch.nn.Module): Module with parameters to align. - execution_device (Optional[torch.device]): If provided, overrides the - module's execution device within the context. - - Yields: - None: Yields control while the module's parameters are - aligned to the execution device. - """ - if has_offloaded_params(module): - if execution_device is not None: - original_device = module._hf_hook.execution_device - module._hf_hook.execution_device = execution_device - - module._hf_hook.pre_forward(module) - yield - module._hf_hook.post_forward(module, None) - - if execution_device is not None: - module._hf_hook.execution_device = original_device - - elif execution_device is not None: - devices = {} - for name, param in module.named_parameters(): - devices[name] = param.device - set_module_tensor_to_device( - module, - name, - execution_device, - ) - - yield - - for name, param in module.named_parameters(): - set_module_tensor_to_device( - module, - name, - devices[name], - ) - - else: - yield - - -@contextlib.contextmanager -def modify_offload_module( - module: torch.nn.Module, - execution_device: Optional[torch.device] = None, - offload_device: Optional[torch.device] = None, -): - with align_module(module, execution_device): - yield - - # there is little performance gain from checking if a parameter's data - # has been modified before copying since the new data must be copied - # to the offload device anyways; just update all module parameters - for name, param in module.named_parameters(): - update_offload_parameter(module, name, param.data, offload_device) - - -# upstream candidate? -def register_offload_parameter( - module: torch.nn.Module, - name: str, - parameter: torch.nn.Parameter, - offload_device: Optional[torch.device] = None, -): - module.register_parameter(name, parameter) - update_offload_parameter(module, name, parameter.data, offload_device) - - -# upstream candidate? -def delete_offload_parameter(module: torch.nn.Module, name: str): - delattr(module, name) - - if has_offloaded_params(module): - weights_map = module._hf_hook.weights_map - - # for upstreaming, probably better to modify the - # weight map types so that they can be written to? - if isinstance(weights_map, PrefixedDataset): - dataset = weights_map.dataset - prefix = weights_map.prefix - if dataset is not None: - del dataset[f"{prefix}{name}"] - - elif isinstance(weights_map, OffloadedWeightsLoader): - raise NotImplementedError() - - elif weights_map is not None: - raise NotImplementedError( - f"Cannot delete parameter from weights_map of type {type(weights_map)}" - ) diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index e3aa644d0..4cbf42a91 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -19,14 +19,11 @@ from collections import OrderedDict from io import BytesIO from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Tuple, Union from urllib.parse import urlparse import numpy import torch -from accelerate.hooks import add_hook_to_module, remove_hook_from_module -from accelerate.utils import PrefixedDataset -from compressed_tensors import is_module_offloaded from compressed_tensors.quantization import disable_quantization, enable_quantization from loguru import logger @@ -1114,66 +1111,3 @@ def calibration_forward_context(model: torch.nn.Module): DisableQuantization(model), ): yield - - -@contextlib.contextmanager -def align_module(module: torch.nn.Module, device: Optional[torch.device] = None): - """ - Move an offloaded module's parameters to device or module execution device - - :param module: module with parameters to align - :param device: optional device to move parameters to, if None is provided then - module execution device will be used - """ - if is_module_offloaded(module): - if device is not None: - original_device = module._hf_hook.execution_device - module._hf_hook.execution_device = device - - module._hf_hook.pre_forward(module) - yield - module._hf_hook.post_forward(module, torch.tensor([])) - - if device is not None: - module._hf_hook.execution_device = original_device - - elif device is not None: - devices = {} - for name, param in module.named_parameters(recurse=False): - devices[name] = param.device - setattr(module, name, param.to(device)) - - yield - - for name, param_device in module.named_parameters: - setattr(module, name, param.to(param_device)) - - else: - yield - - -@contextlib.contextmanager -def disable_hf_hook(module: torch.nn.Module, recurse: bool = False): - hook = module._hf_hook - prefix_dict = module._hf_hook.weights_map - new_prefix = {} - - # recreate the prefix dict (since it is immutable) - # and add quantization parameters - if prefix_dict is not None: - for key, data in module.named_parameters(): - if key not in prefix_dict: - new_prefix[f"{prefix_dict.prefix}{key}"] = data - else: - new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key] - prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix) - - # removing the AlignDevicesHook also moves the module to the original device - remove_hook_from_module(module, recurse=recurse) - - yield - - # we need to re-add the hook for offloading now that we've wrapped forward - add_hook_to_module(module, hook) - if prefix_dict is not None: - module._hf_hook.weights_map = prefix_dict From 5e185f20aed5c28087dce66d74745676bb7d07c8 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 17:14:17 +0000 Subject: [PATCH 158/285] sequential targets --- .../modifiers/quantization/gptq/base.py | 12 ++++++++++-- src/llmcompressor/pipelines/piecewise/helpers.py | 13 ++++--------- src/llmcompressor/pipelines/piecewise/pipeline.py | 8 ++++---- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 6362e0d16..96c9f9da4 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -26,7 +26,7 @@ from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.pipelines.piecewise import run_pipeline from llmcompressor.utils.metric_logging import CompressionLogger -from llmcompressor.utils.pytorch.module import qat_active +from llmcompressor.utils.pytorch.module import get_no_split_params, qat_active __all__ = ["GPTQModifier"] @@ -204,7 +204,15 @@ def on_initialize(self, state: "State", **kwargs) -> bool: post_hook = partial(self.compress_module, name) self.register_hook(module, post_hook, "forward") - run_pipeline(state.model, None, state.data.calib, propagate_error=True) + # infer sequential targets + if self.sequential_targets is None: + self.sequential_targets = get_no_split_params(state.model) + elif isinstance(self.sequential_targets, str): + self.sequential_targets = [self.sequential_targets] + + run_pipeline( + state.model, self.sequential_targets, state.data.calib, propagate_error=True + ) return True diff --git a/src/llmcompressor/pipelines/piecewise/helpers.py b/src/llmcompressor/pipelines/piecewise/helpers.py index 80166f292..e558f65a6 100644 --- a/src/llmcompressor/pipelines/piecewise/helpers.py +++ b/src/llmcompressor/pipelines/piecewise/helpers.py @@ -9,9 +9,7 @@ from transformers.utils.fx import HFTracer from llmcompressor.modifiers.utils.hooks import HooksMixin -from llmcompressor.recipe import Recipe from llmcompressor.utils.helpers import calibration_forward_context -from llmcompressor.utils.pytorch.module import get_no_split_params @dataclass @@ -21,20 +19,17 @@ class Subgraph: consumed_names: List[str] -__all__ = ["get_compression_targets", "trace_subgraphs"] +__all__ = ["infer_sequential_targets", "trace_subgraphs"] -def get_compression_targets(model: Module, recipe: Recipe) -> Set[Module]: +def infer_sequential_targets(model: Module, targets: List[str]) -> Set[Module]: """ - TODO: true sequential + Future: infer from recipe List of modules which are guaranteed to be split into different partitions and whose inner operations will not be traced """ - no_split_params = get_no_split_params(model) - return set( - module for module in model.modules() if type(module).__name__ in no_split_params - ) + return set(module for module in model.modules() if type(module).__name__ in targets) def trace_subgraphs( diff --git a/src/llmcompressor/pipelines/piecewise/pipeline.py b/src/llmcompressor/pipelines/piecewise/pipeline.py index 142a7ce11..853ad22aa 100644 --- a/src/llmcompressor/pipelines/piecewise/pipeline.py +++ b/src/llmcompressor/pipelines/piecewise/pipeline.py @@ -1,4 +1,5 @@ from contextlib import nullcontext +from typing import List import torch import torch.utils.data.dataloader @@ -7,11 +8,10 @@ from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch from llmcompressor.pipelines.piecewise.helpers import ( - get_compression_targets, + infer_sequential_targets, trace_subgraphs, ) from llmcompressor.pytorch.utils.helpers import tensors_to_device -from llmcompressor.recipe.recipe import Recipe from llmcompressor.utils.helpers import calibration_forward_context __all__ = ["run_pipeline"] @@ -19,13 +19,13 @@ def run_pipeline( model: torch.nn.Module, - recipe: Recipe, + targets: List[str], # future: replace with recipe dataloader: torch.utils.data.DataLoader, propagate_error: bool, ): # trace subgraphs sample_input = next(iter(dataloader)) - targets = get_compression_targets(model, recipe) + targets = infer_sequential_targets(model, targets) subgraphs = trace_subgraphs(model, sample_input, targets) # FUTURE: apply recipe to model From 4d8218044af6964509f46948d0c90ff90d0c57ea Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 17:20:04 +0000 Subject: [PATCH 159/285] support match_layers_params --- src/llmcompressor/modifiers/quantization/gptq/base.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 96c9f9da4..93ecf6d92 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -26,7 +26,11 @@ from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.pipelines.piecewise import run_pipeline from llmcompressor.utils.metric_logging import CompressionLogger -from llmcompressor.utils.pytorch.module import get_no_split_params, qat_active +from llmcompressor.utils.pytorch.module import ( + get_layers, + get_no_split_params, + qat_active, +) __all__ = ["GPTQModifier"] @@ -208,7 +212,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: if self.sequential_targets is None: self.sequential_targets = get_no_split_params(state.model) elif isinstance(self.sequential_targets, str): - self.sequential_targets = [self.sequential_targets] + self.sequential_targets = get_layers(self.sequential_targets, self.model) run_pipeline( state.model, self.sequential_targets, state.data.calib, propagate_error=True From 6a1b2c2fa3974f3e7ce5d28cf8903e6c5ff023c6 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 18:07:35 +0000 Subject: [PATCH 160/285] make _update_size private and inferred --- src/llmcompressor/core/events/event.py | 3 +++ src/llmcompressor/core/session_functions.py | 4 ++++ .../modifiers/quantization/gptq/base.py | 17 ++++++++++------- .../pipelines/piecewise/pipeline.py | 5 ++++- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/llmcompressor/core/events/event.py b/src/llmcompressor/core/events/event.py index 9d5d48d63..a31301189 100644 --- a/src/llmcompressor/core/events/event.py +++ b/src/llmcompressor/core/events/event.py @@ -45,6 +45,7 @@ class EventType(Enum): # batch lifecycle BATCH_START = "batch_start" LOSS_CALCULATED = "loss_calculated" + SUB_BATCH_END = "sub_batch_end" BATCH_END = "batch_end" # step lifecycle @@ -74,6 +75,8 @@ def order(self) -> int: return 120 elif self == EventType.OPTIM_POST_STEP: return 130 + elif self == EventType.SUB_BATCH_END: + return 135 elif self == EventType.BATCH_END: return 140 else: diff --git a/src/llmcompressor/core/session_functions.py b/src/llmcompressor/core/session_functions.py index 9a123a030..f08b314e8 100644 --- a/src/llmcompressor/core/session_functions.py +++ b/src/llmcompressor/core/session_functions.py @@ -268,6 +268,10 @@ def optim_post_step(cls, **kwargs) -> ModifiedState: :return: the modified state of the active session after invoking the event """ return cls.event(EventType.OPTIM_POST_STEP, **kwargs) + + @classmethod + def sub_batch_end(cls, **kwargs) -> ModifiedState: + cls.event(EventType.SUB_BATCH_END, **kwargs) @classmethod def batch_end(cls, **kwargs) -> ModifiedState: diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 93ecf6d92..11e3f8440 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -71,10 +71,9 @@ class GPTQModifier(Modifier, HooksMixin): | actorder: False - :param sequential_update: Whether or not to update weights sequentially by layer. - This option is depreciated and setting to False is no longer supported :param sequential_targets: list of layer names to compress during GPTQ, or '__ALL__' to compress every layer in the model + :param block_size: Used to determine number of columns to compress in one pass :param quantize: Set to True to quantize using an existing quantization modifier, or pass in the configuration for a quantization modifier if one does not @@ -104,7 +103,6 @@ class GPTQModifier(Modifier, HooksMixin): # gptq modifier arguments sequential_update: bool = True # DEPRECIATED - update_size: int = 512 sequential_targets: Union[str, List[str], None] = None block_size: int = 128 dampening_frac: Optional[float] = 0.01 @@ -123,6 +121,7 @@ class GPTQModifier(Modifier, HooksMixin): _quantization_modifier: Optional[QuantizationModifier] = PrivateAttr() _hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict) _num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict) + _update_size: Optional[int] = PrivateAttr(default=None) @field_validator("sequential_update", mode="before") def validate_sequential_update(cls, value: bool) -> bool: @@ -214,6 +213,10 @@ def on_initialize(self, state: "State", **kwargs) -> bool: elif isinstance(self.sequential_targets, str): self.sequential_targets = get_layers(self.sequential_targets, self.model) + # infer update size + if self._update_size is None: + self._update_size = len(state.data.calib) + run_pipeline( state.model, self.sequential_targets, state.data.calib, propagate_error=True ) @@ -263,7 +266,7 @@ def compress_module( self._num_samples[module] = 0 # Accumulate hessian with input with optional offloading - with self._maybe_onload_hessians(module): + with self._maybe_onload_hessian(module): self._hessians[module], self._num_samples[module] = accumulate_hessian( inp, type(module), @@ -272,11 +275,11 @@ def compress_module( ) # After enough samples are accumulated, perform quantization - if self._num_samples[module] >= self.update_size: + if self._num_samples[module] >= self._update_size: logger.info(f"Quantizing {name} using {self._num_samples[module]} samples") with ( align_module_device(module), - self._maybe_onload_hessians(module), + self._maybe_onload_hessian(module), CompressionLogger(module) as comp_logger, ): loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( @@ -301,7 +304,7 @@ def compress_module( comp_logger.set_loss(loss) @contextlib.contextmanager - def _maybe_onload_hessians(self, module: torch.nn.Module): + def _maybe_onload_hessian(self, module: torch.nn.Module): if self.offload_hessians: device = get_execution_device(module) self._hessians[module] = self._hessians[module].to(device=device) diff --git a/src/llmcompressor/pipelines/piecewise/pipeline.py b/src/llmcompressor/pipelines/piecewise/pipeline.py index 853ad22aa..d3f18a839 100644 --- a/src/llmcompressor/pipelines/piecewise/pipeline.py +++ b/src/llmcompressor/pipelines/piecewise/pipeline.py @@ -5,6 +5,8 @@ import torch.utils.data.dataloader import tqdm +from llmcompressor.core import callbacks as session_callbacks +from llmcompressor.modifiers.modifier import Modifier from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch from llmcompressor.pipelines.piecewise.helpers import ( @@ -19,7 +21,7 @@ def run_pipeline( model: torch.nn.Module, - targets: List[str], # future: replace with recipe + targets: List[str], # FUTURE: replace with recipe dataloader: torch.utils.data.DataLoader, propagate_error: bool, ): @@ -90,3 +92,4 @@ def run_pipeline( intermediates.update(subgraph_output) else: batch_outputs[batch_index] = subgraph_output + From f9ab6fc003e603f0611d6b981abea04f37afd2bf Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 18:08:32 +0000 Subject: [PATCH 161/285] make a module --- examples/quantization_w4a16/llama3_example.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index c08165299..2690ae780 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -1,3 +1,4 @@ +from accelerate.big_modeling import cpu_offload from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer @@ -5,13 +6,15 @@ from llmcompressor.transformers import oneshot # Select model and load it. -MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" +MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" model = AutoModelForCausalLM.from_pretrained( MODEL_ID, - device_map="auto", + # device_map="auto", + device_map="cuda:0", torch_dtype="auto", ) +cpu_offload(model) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) # Select calibration dataset. @@ -55,7 +58,9 @@ def tokenize(sample): # Configure the quantization algorithm to run. # * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]) +recipe = GPTQModifier( + targets="Linear", scheme="W4A16", ignore=["lm_head"], offload_hessians=False +) # Apply algorithms. oneshot( From 0dc74dd2294132e6c9659ac96ade48bc7a042e78 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 20:44:55 +0000 Subject: [PATCH 162/285] fallback --- examples/multimodal_vision/mllama.py | 1 - examples/multimodal_vision/pixtral.py | 1 - examples/multimodal_vision/qwen_vl2.py | 1 - .../modifiers/quantization/gptq/base.py | 14 +++++--- src/llmcompressor/pipelines/__init__.py | 0 src/llmcompressor/pipelines/basic/__init__.py | 1 + src/llmcompressor/pipelines/basic/pipeline.py | 36 +++++++++++++++++++ 7 files changed, 47 insertions(+), 7 deletions(-) create mode 100644 src/llmcompressor/pipelines/__init__.py create mode 100644 src/llmcompressor/pipelines/basic/__init__.py create mode 100644 src/llmcompressor/pipelines/basic/pipeline.py diff --git a/examples/multimodal_vision/mllama.py b/examples/multimodal_vision/mllama.py index 46f9acabe..3d1ba24af 100644 --- a/examples/multimodal_vision/mllama.py +++ b/examples/multimodal_vision/mllama.py @@ -40,7 +40,6 @@ def data_collator(batch): targets="Linear", scheme="W8A8", ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"], - update_size=NUM_CALIBRATION_SAMPLES, ), ] diff --git a/examples/multimodal_vision/pixtral.py b/examples/multimodal_vision/pixtral.py index 14187d7d1..f8aaea54c 100644 --- a/examples/multimodal_vision/pixtral.py +++ b/examples/multimodal_vision/pixtral.py @@ -38,7 +38,6 @@ def data_collator(batch): targets="Linear", scheme="W8A8", ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"], - update_size=NUM_CALIBRATION_SAMPLES, ), ] diff --git a/examples/multimodal_vision/qwen_vl2.py b/examples/multimodal_vision/qwen_vl2.py index 2de17d4ae..5e470bdc3 100644 --- a/examples/multimodal_vision/qwen_vl2.py +++ b/examples/multimodal_vision/qwen_vl2.py @@ -57,7 +57,6 @@ def data_collator(batch): ), }, ignore=["re:.*lm_head"], - update_size=NUM_CALIBRATION_SAMPLES, dampening_frac=0.5, ) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 11e3f8440..a2f706c00 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -24,7 +24,8 @@ ) from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier from llmcompressor.modifiers.utils.hooks import HooksMixin -from llmcompressor.pipelines.piecewise import run_pipeline +from llmcompressor.pipelines.basic import run_pipeline as run_basic +from llmcompressor.pipelines.piecewise import run_pipeline as run_piecewise from llmcompressor.utils.metric_logging import CompressionLogger from llmcompressor.utils.pytorch.module import ( get_layers, @@ -217,9 +218,14 @@ def on_initialize(self, state: "State", **kwargs) -> bool: if self._update_size is None: self._update_size = len(state.data.calib) - run_pipeline( - state.model, self.sequential_targets, state.data.calib, propagate_error=True - ) + # run_pipeline( + # state.model, self.sequential_targets, state.data.calib, propagate_error=True + # ) + + self.offload_hessians = True + run_basic(state.model, state.data.calib) + + return True diff --git a/src/llmcompressor/pipelines/__init__.py b/src/llmcompressor/pipelines/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/llmcompressor/pipelines/basic/__init__.py b/src/llmcompressor/pipelines/basic/__init__.py new file mode 100644 index 000000000..2ae7cbad3 --- /dev/null +++ b/src/llmcompressor/pipelines/basic/__init__.py @@ -0,0 +1 @@ +from .pipeline import run_pipeline \ No newline at end of file diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py new file mode 100644 index 000000000..58a1362f0 --- /dev/null +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -0,0 +1,36 @@ +from contextlib import nullcontext +from typing import List + +import torch +import torch.utils.data.dataloader +import tqdm + +from llmcompressor.core import callbacks as session_callbacks +from llmcompressor.modifiers.modifier import Modifier +from llmcompressor.modifiers.utils.hooks import HooksMixin +from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch +from llmcompressor.pipelines.piecewise.helpers import ( + infer_sequential_targets, + trace_subgraphs, +) +from llmcompressor.pytorch.utils.helpers import tensors_to_device +from llmcompressor.utils.helpers import calibration_forward_context + +__all__ = ["run_pipeline"] + +def run_pipeline( + model: torch.nn.Module, + dataloader: torch.utils.data.DataLoader, +): + # TODO: revisit + device_map = getattr(model, "hf_device_map", None) + if device_map is not None: + model_device = next(iter(device_map.values())) + else: + model_device = model.device + + for batch in tqdm.tqdm(dataloader, desc="Calibrating"): + batch = apply_pad_mask_to_batch(batch) + batch = tensors_to_device(batch, model_device) + model(**batch) + From 9e07188e8c44a3020860fe5cf92ae227efc1e9f9 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 16:13:59 -0500 Subject: [PATCH 163/285] implement basic pipeline --- examples/multimodal_vision/pixtral.py | 4 ++-- src/llmcompressor/core/session_functions.py | 2 +- .../modifiers/quantization/gptq/base.py | 22 +++++++++++-------- src/llmcompressor/pipelines/basic/__init__.py | 3 ++- src/llmcompressor/pipelines/basic/pipeline.py | 21 +++++------------- .../pipelines/piecewise/pipeline.py | 3 --- 6 files changed, 24 insertions(+), 31 deletions(-) diff --git a/examples/multimodal_vision/pixtral.py b/examples/multimodal_vision/pixtral.py index f8aaea54c..73c445427 100644 --- a/examples/multimodal_vision/pixtral.py +++ b/examples/multimodal_vision/pixtral.py @@ -17,7 +17,7 @@ # Oneshot arguments DATASET_ID = "flickr30k" DATASET_SPLIT = "test[:512]" -NUM_CALIBRATION_SAMPLES = 1 +NUM_CALIBRATION_SAMPLES = 512 MAX_SEQUENCE_LENGTH = 2048 @@ -27,7 +27,7 @@ def data_collator(batch): return { "input_ids": torch.LongTensor(batch[0]["input_ids"]), "attention_mask": torch.tensor(batch[0]["attention_mask"]), - "pixel_values": torch.tensor(batch[0]["pixel_values"]), + "pixel_values": torch.tensor(batch[0]["pixel_values"])[0], } diff --git a/src/llmcompressor/core/session_functions.py b/src/llmcompressor/core/session_functions.py index f08b314e8..c30bb08fd 100644 --- a/src/llmcompressor/core/session_functions.py +++ b/src/llmcompressor/core/session_functions.py @@ -268,7 +268,7 @@ def optim_post_step(cls, **kwargs) -> ModifiedState: :return: the modified state of the active session after invoking the event """ return cls.event(EventType.OPTIM_POST_STEP, **kwargs) - + @classmethod def sub_batch_end(cls, **kwargs) -> ModifiedState: cls.event(EventType.SUB_BATCH_END, **kwargs) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index a2f706c00..9cb1febff 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -74,7 +74,7 @@ class GPTQModifier(Modifier, HooksMixin): :param sequential_targets: list of layer names to compress during GPTQ, or '__ALL__' to compress every layer in the model - + :param block_size: Used to determine number of columns to compress in one pass :param quantize: Set to True to quantize using an existing quantization modifier, or pass in the configuration for a quantization modifier if one does not @@ -218,14 +218,17 @@ def on_initialize(self, state: "State", **kwargs) -> bool: if self._update_size is None: self._update_size = len(state.data.calib) - # run_pipeline( - # state.model, self.sequential_targets, state.data.calib, propagate_error=True - # ) - - self.offload_hessians = True - run_basic(state.model, state.data.calib) - - + # infer pipeline + if "pixel_values" not in state.data.calib.dataset.column_names: + run_piecewise( + state.model, + self.sequential_targets, + state.data.calib, + propagate_error=True, + ) + else: + self.offload_hessians = True + run_basic(state.model, state.data.calib) return True @@ -284,6 +287,7 @@ def compress_module( if self._num_samples[module] >= self._update_size: logger.info(f"Quantizing {name} using {self._num_samples[module]} samples") with ( + torch.no_grad(), align_module_device(module), self._maybe_onload_hessian(module), CompressionLogger(module) as comp_logger, diff --git a/src/llmcompressor/pipelines/basic/__init__.py b/src/llmcompressor/pipelines/basic/__init__.py index 2ae7cbad3..fc60475ca 100644 --- a/src/llmcompressor/pipelines/basic/__init__.py +++ b/src/llmcompressor/pipelines/basic/__init__.py @@ -1 +1,2 @@ -from .pipeline import run_pipeline \ No newline at end of file +# flake8: noqa +from .pipeline import run_pipeline diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index 58a1362f0..31f4c01be 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -1,23 +1,14 @@ -from contextlib import nullcontext -from typing import List - import torch import torch.utils.data.dataloader import tqdm -from llmcompressor.core import callbacks as session_callbacks -from llmcompressor.modifiers.modifier import Modifier -from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch -from llmcompressor.pipelines.piecewise.helpers import ( - infer_sequential_targets, - trace_subgraphs, -) from llmcompressor.pytorch.utils.helpers import tensors_to_device from llmcompressor.utils.helpers import calibration_forward_context __all__ = ["run_pipeline"] + def run_pipeline( model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, @@ -29,8 +20,8 @@ def run_pipeline( else: model_device = model.device - for batch in tqdm.tqdm(dataloader, desc="Calibrating"): - batch = apply_pad_mask_to_batch(batch) - batch = tensors_to_device(batch, model_device) - model(**batch) - + with calibration_forward_context(model): + for batch in tqdm.tqdm(dataloader, desc="Calibrating"): + batch = apply_pad_mask_to_batch(batch) + batch = tensors_to_device(batch, model_device) + model(**batch) diff --git a/src/llmcompressor/pipelines/piecewise/pipeline.py b/src/llmcompressor/pipelines/piecewise/pipeline.py index d3f18a839..262d6c2ed 100644 --- a/src/llmcompressor/pipelines/piecewise/pipeline.py +++ b/src/llmcompressor/pipelines/piecewise/pipeline.py @@ -5,8 +5,6 @@ import torch.utils.data.dataloader import tqdm -from llmcompressor.core import callbacks as session_callbacks -from llmcompressor.modifiers.modifier import Modifier from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch from llmcompressor.pipelines.piecewise.helpers import ( @@ -92,4 +90,3 @@ def run_pipeline( intermediates.update(subgraph_output) else: batch_outputs[batch_index] = subgraph_output - From ed099ef498cc793b97f13ffd4a137b7293d7e418 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 16:27:38 -0500 Subject: [PATCH 164/285] balance between gpus --- examples/multimodal_vision/pixtral.py | 2 +- src/llmcompressor/modifiers/quantization/gptq/base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/multimodal_vision/pixtral.py b/examples/multimodal_vision/pixtral.py index 73c445427..b4f1f1ff0 100644 --- a/examples/multimodal_vision/pixtral.py +++ b/examples/multimodal_vision/pixtral.py @@ -10,7 +10,7 @@ # Load model. model_id = "mgoin/pixtral-12b" model = TracableLlavaForConditionalGeneration.from_pretrained( - model_id, device_map="auto", torch_dtype="auto" + model_id, device_map="balanced", torch_dtype="auto" ) processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 9cb1febff..ae949e618 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -227,7 +227,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: propagate_error=True, ) else: - self.offload_hessians = True + # self.offload_hessians = True run_basic(state.model, state.data.calib) return True From 4bbbc49ecdf9ba251f3a8636cb4a54e66231be6c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 21:57:04 +0000 Subject: [PATCH 165/285] add proper ignore list --- examples/multimodal_vision/pixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/multimodal_vision/pixtral.py b/examples/multimodal_vision/pixtral.py index b4f1f1ff0..13efe4c20 100644 --- a/examples/multimodal_vision/pixtral.py +++ b/examples/multimodal_vision/pixtral.py @@ -37,7 +37,7 @@ def data_collator(batch): GPTQModifier( targets="Linear", scheme="W8A8", - ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"], + ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"], ), ] From ae74f4551ee5893f6ce6a6f65c89b69065820c1d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 7 Dec 2024 00:10:54 -0500 Subject: [PATCH 166/285] treat offloaded modules as leaves, treat ignore as sequential target --- .../modifiers/quantization/gptq/base.py | 3 +- .../pipelines/piecewise/helpers.py | 31 ++++++++++++++++--- .../pipelines/piecewise/pipeline.py | 9 ++++-- src/llmcompressor/pytorch/tracing/llava.py | 4 ++- 4 files changed, 38 insertions(+), 9 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index ae949e618..fba4ac3c6 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -219,10 +219,11 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self._update_size = len(state.data.calib) # infer pipeline - if "pixel_values" not in state.data.calib.dataset.column_names: + if True: # if "pixel_values" not in state.data.calib.dataset.column_names: run_piecewise( state.model, self.sequential_targets, + self.ignore, state.data.calib, propagate_error=True, ) diff --git a/src/llmcompressor/pipelines/piecewise/helpers.py b/src/llmcompressor/pipelines/piecewise/helpers.py index e558f65a6..383f7dc91 100644 --- a/src/llmcompressor/pipelines/piecewise/helpers.py +++ b/src/llmcompressor/pipelines/piecewise/helpers.py @@ -3,6 +3,8 @@ from dataclasses import dataclass from typing import Any, Dict, List, Set +from compressed_tensors import has_offloaded_params +from compressed_tensors.quantization import find_name_or_class_matches from compressed_tensors.utils import disable_hf_hook from torch.fx import Graph, GraphModule, Node from torch.nn import Module @@ -22,21 +24,31 @@ class Subgraph: __all__ = ["infer_sequential_targets", "trace_subgraphs"] -def infer_sequential_targets(model: Module, targets: List[str]) -> Set[Module]: +def infer_sequential_targets( + model: Module, sequential_targets: List[str], ignore: List[str] +) -> Set[Module]: """ Future: infer from recipe List of modules which are guaranteed to be split into different partitions and whose inner operations will not be traced """ - return set(module for module in model.modules() if type(module).__name__ in targets) + targets_names = sequential_targets + ignore + + sequential_targets = set( + module + for name, module in model.named_modules() + if find_name_or_class_matches(name, module, targets_names) + ) + + return sequential_targets def trace_subgraphs( model: Module, sample_input: Dict[str, Any], targets: Set[Module] ) -> List[Subgraph]: # initialize arguments - tracer = get_tracer(targets) + tracer = get_tracer(model, targets) concrete_args = populate_concrete_args(model, sample_input) # trace @@ -70,11 +82,20 @@ def trace_subgraphs( return subgraphs -def get_tracer(targets: List[Module]) -> HFTracer: +def get_tracer(model: Module, targets: List[Module]) -> HFTracer: + offloaded_modules = set( + module for module in model.modules() if has_offloaded_params(module) + ) + class PiecewiseTracer(HFTracer): + # Treat as leaf, skip tracing inside this module def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool: if module in targets: - return True # Treat as leaf, skip tracing inside this module + return True + + if module in offloaded_modules: + return True + return super().is_leaf_module(module, module_qualified_name) return PiecewiseTracer() diff --git a/src/llmcompressor/pipelines/piecewise/pipeline.py b/src/llmcompressor/pipelines/piecewise/pipeline.py index 262d6c2ed..f3b517dcb 100644 --- a/src/llmcompressor/pipelines/piecewise/pipeline.py +++ b/src/llmcompressor/pipelines/piecewise/pipeline.py @@ -19,13 +19,14 @@ def run_pipeline( model: torch.nn.Module, - targets: List[str], # FUTURE: replace with recipe + sequential_targets: List[str], # FUTURE: replace with recipe inference + ignore: List[str], dataloader: torch.utils.data.DataLoader, propagate_error: bool, ): # trace subgraphs sample_input = next(iter(dataloader)) - targets = infer_sequential_targets(model, targets) + targets = infer_sequential_targets(model, sequential_targets, ignore) subgraphs = trace_subgraphs(model, sample_input, targets) # FUTURE: apply recipe to model @@ -64,6 +65,10 @@ def run_pipeline( input_name: intermediates[input_name] for input_name in subgraph.input_names } + # TODO: put on first device from + # subgraph.graph.find_nodes(op="call_module") + # since find_nodes is topologically sorted + # or get execution device of GraphModule, recursively inputs = tensors_to_device(inputs, model_device) forward_function(model, **inputs) diff --git a/src/llmcompressor/pytorch/tracing/llava.py b/src/llmcompressor/pytorch/tracing/llava.py index 23f7a6bbb..89da7cdea 100644 --- a/src/llmcompressor/pytorch/tracing/llava.py +++ b/src/llmcompressor/pytorch/tracing/llava.py @@ -120,7 +120,9 @@ def forward( n_image_tokens = (input_ids == self.config.image_token_index).sum().item() n_image_features = image_features.shape[0] * image_features.shape[1] - if n_image_tokens != n_image_features: + # NOT TRACABLE, instead always use n_image_tokens != n_image_features = False + #if n_image_tokens != n_image_features: + if False: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) From 31eeb8c9841a1bc8da52f50d230fe265402f95f4 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 7 Dec 2024 00:18:41 -0500 Subject: [PATCH 167/285] redisable piecewise for vision datasets --- src/llmcompressor/modifiers/quantization/gptq/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index fba4ac3c6..abaac931f 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -219,7 +219,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self._update_size = len(state.data.calib) # infer pipeline - if True: # if "pixel_values" not in state.data.calib.dataset.column_names: + if "pixel_values" not in state.data.calib.dataset.column_names: run_piecewise( state.model, self.sequential_targets, From 1b24090efd4e06adf85debe78f71f1842bd955f0 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 9 Dec 2024 12:26:59 -0500 Subject: [PATCH 168/285] implement pipeline fallback --- examples/multimodal_vision/pixtral.py | 2 ++ .../modifiers/quantization/gptq/base.py | 28 +++++++++++++------ .../pipelines/piecewise/pipeline.py | 3 +- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/examples/multimodal_vision/pixtral.py b/examples/multimodal_vision/pixtral.py index 13efe4c20..dee690480 100644 --- a/examples/multimodal_vision/pixtral.py +++ b/examples/multimodal_vision/pixtral.py @@ -38,6 +38,7 @@ def data_collator(batch): targets="Linear", scheme="W8A8", ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"], + sequential_targets=["MistralDecoderLayer"], ), ] @@ -58,6 +59,7 @@ def data_collator(batch): data_collator=data_collator, ) +model.save_pretrained(save_path) processor.save_pretrained(save_path) # Confirm generations of the quantized model look sane. diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index abaac931f..056aa4c41 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -1,4 +1,5 @@ import contextlib +import traceback import warnings from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union @@ -220,15 +221,26 @@ def on_initialize(self, state: "State", **kwargs) -> bool: # infer pipeline if "pixel_values" not in state.data.calib.dataset.column_names: - run_piecewise( - state.model, - self.sequential_targets, - self.ignore, - state.data.calib, - propagate_error=True, - ) + try: + run_piecewise( + state.model, + self.sequential_targets, + self.ignore, + state.data.calib, + propagate_error=True, + ) + + except torch.fx.proxy.TraceError: + print(traceback.format_exc()) + warnings.warn( + "Failed to trace model graph, using non-sequential " + "pipeline with `offload_hessians = True`" + ) + self.offload_hessians = True + run_basic(state.model, state.data.calib) + else: - # self.offload_hessians = True + warnings.warn("Cannot use sequential pipeline with vision datasets") run_basic(state.model, state.data.calib) return True diff --git a/src/llmcompressor/pipelines/piecewise/pipeline.py b/src/llmcompressor/pipelines/piecewise/pipeline.py index f3b517dcb..be9d062cc 100644 --- a/src/llmcompressor/pipelines/piecewise/pipeline.py +++ b/src/llmcompressor/pipelines/piecewise/pipeline.py @@ -41,8 +41,9 @@ def run_pipeline( with calibration_forward_context(model): # prepare intermediates cache + desc = "Preparing intermediates cache" batch_intermediates = [ - apply_pad_mask_to_batch(batch) for batch in iter(dataloader) + apply_pad_mask_to_batch(batch) for batch in tqdm.tqdm(dataloader, desc=desc) ] batch_outputs = [None for _ in range(len(dataloader))] From e87e019f5cbbcffc666f444da4820fd519ffc616 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 9 Dec 2024 12:38:21 -0500 Subject: [PATCH 169/285] remove subbatch event --- src/llmcompressor/core/events/event.py | 3 --- src/llmcompressor/core/session_functions.py | 4 ---- 2 files changed, 7 deletions(-) diff --git a/src/llmcompressor/core/events/event.py b/src/llmcompressor/core/events/event.py index a31301189..9d5d48d63 100644 --- a/src/llmcompressor/core/events/event.py +++ b/src/llmcompressor/core/events/event.py @@ -45,7 +45,6 @@ class EventType(Enum): # batch lifecycle BATCH_START = "batch_start" LOSS_CALCULATED = "loss_calculated" - SUB_BATCH_END = "sub_batch_end" BATCH_END = "batch_end" # step lifecycle @@ -75,8 +74,6 @@ def order(self) -> int: return 120 elif self == EventType.OPTIM_POST_STEP: return 130 - elif self == EventType.SUB_BATCH_END: - return 135 elif self == EventType.BATCH_END: return 140 else: diff --git a/src/llmcompressor/core/session_functions.py b/src/llmcompressor/core/session_functions.py index c30bb08fd..9a123a030 100644 --- a/src/llmcompressor/core/session_functions.py +++ b/src/llmcompressor/core/session_functions.py @@ -269,10 +269,6 @@ def optim_post_step(cls, **kwargs) -> ModifiedState: """ return cls.event(EventType.OPTIM_POST_STEP, **kwargs) - @classmethod - def sub_batch_end(cls, **kwargs) -> ModifiedState: - cls.event(EventType.SUB_BATCH_END, **kwargs) - @classmethod def batch_end(cls, **kwargs) -> ModifiedState: """ From d5c08fbf4a73ac558e59e183630129e85e0c5f0a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 9 Dec 2024 12:57:16 -0500 Subject: [PATCH 170/285] input device inference --- examples/quantization_w4a16/llama3_example.py | 11 +++-------- src/llmcompressor/pipelines/piecewise/helpers.py | 15 ++++++++++++++- src/llmcompressor/pipelines/piecewise/pipeline.py | 15 ++------------- 3 files changed, 19 insertions(+), 22 deletions(-) diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index 2690ae780..c08165299 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -1,4 +1,3 @@ -from accelerate.big_modeling import cpu_offload from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer @@ -6,15 +5,13 @@ from llmcompressor.transformers import oneshot # Select model and load it. -MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" +MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" model = AutoModelForCausalLM.from_pretrained( MODEL_ID, - # device_map="auto", - device_map="cuda:0", + device_map="auto", torch_dtype="auto", ) -cpu_offload(model) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) # Select calibration dataset. @@ -58,9 +55,7 @@ def tokenize(sample): # Configure the quantization algorithm to run. # * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = GPTQModifier( - targets="Linear", scheme="W4A16", ignore=["lm_head"], offload_hessians=False -) +recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]) # Apply algorithms. oneshot( diff --git a/src/llmcompressor/pipelines/piecewise/helpers.py b/src/llmcompressor/pipelines/piecewise/helpers.py index 383f7dc91..f409a4275 100644 --- a/src/llmcompressor/pipelines/piecewise/helpers.py +++ b/src/llmcompressor/pipelines/piecewise/helpers.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from typing import Any, Dict, List, Set +import torch from compressed_tensors import has_offloaded_params from compressed_tensors.quantization import find_name_or_class_matches from compressed_tensors.utils import disable_hf_hook @@ -19,6 +20,7 @@ class Subgraph: graph: Graph input_names: List[str] consumed_names: List[str] + input_device: torch.device __all__ = ["infer_sequential_targets", "trace_subgraphs"] @@ -235,7 +237,17 @@ def partition_graph(model: Module, partitions: List[List[Node]]) -> List[Subgrap } graph.output(output_dict) - # Save the subgraph for this partition + # find input device for subgraph + # note find_nodes is topologically sorted + modules = graph.find_nodes(op="call_module") + first_offloaded = next((m for m in modules if has_offloaded_params(m)), None) + input_device = ( + first_offloaded.execution_device + if first_offloaded is not None + else model.device + ) + + # save the subgraph for this partition graph.lint() input_names = [node.name for node in graph.nodes if node.op == "placeholder"] subgraphs.append( @@ -243,6 +255,7 @@ def partition_graph(model: Module, partitions: List[List[Node]]) -> List[Subgrap graph=graph, input_names=input_names, consumed_names=[], # populated later + input_device=input_device, ) ) diff --git a/src/llmcompressor/pipelines/piecewise/pipeline.py b/src/llmcompressor/pipelines/piecewise/pipeline.py index be9d062cc..4efb7d954 100644 --- a/src/llmcompressor/pipelines/piecewise/pipeline.py +++ b/src/llmcompressor/pipelines/piecewise/pipeline.py @@ -32,13 +32,6 @@ def run_pipeline( # FUTURE: apply recipe to model # initialize(recipe, model) - # TODO: revisit - device_map = getattr(model, "hf_device_map", None) - if device_map is not None: - model_device = next(iter(device_map.values())) - else: - model_device = model.device - with calibration_forward_context(model): # prepare intermediates cache desc = "Preparing intermediates cache" @@ -66,11 +59,7 @@ def run_pipeline( input_name: intermediates[input_name] for input_name in subgraph.input_names } - # TODO: put on first device from - # subgraph.graph.find_nodes(op="call_module") - # since find_nodes is topologically sorted - # or get execution device of GraphModule, recursively - inputs = tensors_to_device(inputs, model_device) + inputs = tensors_to_device(inputs, subgraph.input_device) forward_function(model, **inputs) # if using propagate_error, then this pass does not trigger modifier hooks @@ -85,7 +74,7 @@ def run_pipeline( input_name: intermediates[input_name] for input_name in subgraph.input_names } - inputs = tensors_to_device(inputs, model_device) + inputs = tensors_to_device(inputs, subgraph.input_device) subgraph_output = forward_function(model, **inputs) subgraph_output = tensors_to_device(subgraph_output, "cpu") From 39ed8cae0347c2e84395e708efc4b0982babaea8 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 9 Dec 2024 19:16:30 +0000 Subject: [PATCH 171/285] do not disable hf hook during tracing --- src/llmcompressor/pipelines/piecewise/helpers.py | 13 ++++++++----- src/llmcompressor/pipelines/piecewise/pipeline.py | 2 ++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/llmcompressor/pipelines/piecewise/helpers.py b/src/llmcompressor/pipelines/piecewise/helpers.py index f409a4275..0b1bc99b4 100644 --- a/src/llmcompressor/pipelines/piecewise/helpers.py +++ b/src/llmcompressor/pipelines/piecewise/helpers.py @@ -12,7 +12,7 @@ from transformers.utils.fx import HFTracer from llmcompressor.modifiers.utils.hooks import HooksMixin -from llmcompressor.utils.helpers import calibration_forward_context +from llmcompressor.utils.helpers import calibration_forward_context, getattr_chain @dataclass @@ -57,7 +57,7 @@ def trace_subgraphs( with ( calibration_forward_context(model), HooksMixin.disable_hooks(), - disable_hf_hook(model, recurse=True), + # disable_hf_hook(model, recurse=True), ): graph = GraphModule( model, @@ -238,11 +238,14 @@ def partition_graph(model: Module, partitions: List[List[Node]]) -> List[Subgrap graph.output(output_dict) # find input device for subgraph - # note find_nodes is topologically sorted - modules = graph.find_nodes(op="call_module") + # note: find_nodes is topologically sorted + modules = [ + getattr_chain(model, node.target) + for node in graph.find_nodes(op="call_module") + ] first_offloaded = next((m for m in modules if has_offloaded_params(m)), None) input_device = ( - first_offloaded.execution_device + torch.device(first_offloaded._hf_hook.execution_device) if first_offloaded is not None else model.device ) diff --git a/src/llmcompressor/pipelines/piecewise/pipeline.py b/src/llmcompressor/pipelines/piecewise/pipeline.py index 4efb7d954..115f70085 100644 --- a/src/llmcompressor/pipelines/piecewise/pipeline.py +++ b/src/llmcompressor/pipelines/piecewise/pipeline.py @@ -59,6 +59,8 @@ def run_pipeline( input_name: intermediates[input_name] for input_name in subgraph.input_names } + # graph_module = torch.fx.GraphModule(model, subgraph.graph) + # breakpoint() inputs = tensors_to_device(inputs, subgraph.input_device) forward_function(model, **inputs) From 4711e9ff399adb45ae092ff3a009f26d944f9ad5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 9 Dec 2024 20:03:16 +0000 Subject: [PATCH 172/285] remove import --- src/llmcompressor/pipelines/piecewise/helpers.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/llmcompressor/pipelines/piecewise/helpers.py b/src/llmcompressor/pipelines/piecewise/helpers.py index 0b1bc99b4..a44473031 100644 --- a/src/llmcompressor/pipelines/piecewise/helpers.py +++ b/src/llmcompressor/pipelines/piecewise/helpers.py @@ -6,7 +6,6 @@ import torch from compressed_tensors import has_offloaded_params from compressed_tensors.quantization import find_name_or_class_matches -from compressed_tensors.utils import disable_hf_hook from torch.fx import Graph, GraphModule, Node from torch.nn import Module from transformers.utils.fx import HFTracer @@ -92,13 +91,11 @@ def get_tracer(model: Module, targets: List[Module]) -> HFTracer: class PiecewiseTracer(HFTracer): # Treat as leaf, skip tracing inside this module def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool: - if module in targets: - return True - - if module in offloaded_modules: - return True - - return super().is_leaf_module(module, module_qualified_name) + return ( + module in targets + or module in offloaded_modules + or super().is_leaf_module(module, module_qualified_name) + ) return PiecewiseTracer() From e468197dc1ca64b087d8547e5c67fa3e1d90b96a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 9 Dec 2024 20:07:22 +0000 Subject: [PATCH 173/285] use find_nodes --- .../pipelines/piecewise/helpers.py | 19 +++++++++---------- .../finetune/data/test_dataset_loading.py | 1 - 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/llmcompressor/pipelines/piecewise/helpers.py b/src/llmcompressor/pipelines/piecewise/helpers.py index a44473031..4c1d3c017 100644 --- a/src/llmcompressor/pipelines/piecewise/helpers.py +++ b/src/llmcompressor/pipelines/piecewise/helpers.py @@ -186,16 +186,15 @@ def topological_partition(graph: GraphModule, targets: Set[Module]) -> List[List # a perfect solution would involve implicitly consolodating partition indices so # that each node is assigned to the maximum partition possible (in order to delay # execution as long as possible), but this covers the most costly case (get_attr) - for node in graph.graph.nodes: - if node.op == "get_attr": - user_partitions = [] - for user in node.users: - for index in range(len(partitions)): - if user in partitions[index]: - user_partitions.append(index) - break - partition_index = min(user_partitions) - partitions[partition_index].insert(0, node) + for node in graph.graph.find_nodes(op="get_attr"): + user_partitions = [] + for user in node.users: + for index in range(len(partitions)): + if user in partitions[index]: + user_partitions.append(index) + break + partition_index = min(user_partitions) + partitions[partition_index].insert(0, node) assert set().union(*partitions) == set(graph.graph.nodes) return partitions diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py index ef60cb811..64514b252 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py @@ -15,7 +15,6 @@ format_calibration_data, ) from llmcompressor.transformers.finetune.runner import StageRunner -from llmcompressor.transformers.finetune.training_args import TrainingArguments @pytest.mark.unit From f8591cac3dbf0f0eb9604afab1bff1eee530a6ca Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 9 Dec 2024 20:12:00 +0000 Subject: [PATCH 174/285] rename piecewise to sequential --- src/llmcompressor/modifiers/quantization/gptq/base.py | 4 ++-- .../pipelines/{piecewise => sequential}/__init__.py | 0 .../pipelines/{piecewise => sequential}/helpers.py | 0 .../pipelines/{piecewise => sequential}/pipeline.py | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) rename src/llmcompressor/pipelines/{piecewise => sequential}/__init__.py (100%) rename src/llmcompressor/pipelines/{piecewise => sequential}/helpers.py (100%) rename src/llmcompressor/pipelines/{piecewise => sequential}/pipeline.py (98%) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 056aa4c41..8c66c272b 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -26,7 +26,7 @@ from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.pipelines.basic import run_pipeline as run_basic -from llmcompressor.pipelines.piecewise import run_pipeline as run_piecewise +from llmcompressor.pipelines.sequential import run_pipeline as run_sequential from llmcompressor.utils.metric_logging import CompressionLogger from llmcompressor.utils.pytorch.module import ( get_layers, @@ -222,7 +222,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: # infer pipeline if "pixel_values" not in state.data.calib.dataset.column_names: try: - run_piecewise( + run_sequential( state.model, self.sequential_targets, self.ignore, diff --git a/src/llmcompressor/pipelines/piecewise/__init__.py b/src/llmcompressor/pipelines/sequential/__init__.py similarity index 100% rename from src/llmcompressor/pipelines/piecewise/__init__.py rename to src/llmcompressor/pipelines/sequential/__init__.py diff --git a/src/llmcompressor/pipelines/piecewise/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py similarity index 100% rename from src/llmcompressor/pipelines/piecewise/helpers.py rename to src/llmcompressor/pipelines/sequential/helpers.py diff --git a/src/llmcompressor/pipelines/piecewise/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py similarity index 98% rename from src/llmcompressor/pipelines/piecewise/pipeline.py rename to src/llmcompressor/pipelines/sequential/pipeline.py index 115f70085..7c91fd0b4 100644 --- a/src/llmcompressor/pipelines/piecewise/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -7,7 +7,7 @@ from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch -from llmcompressor.pipelines.piecewise.helpers import ( +from llmcompressor.pipelines.sequential.helpers import ( infer_sequential_targets, trace_subgraphs, ) From cea02d25eac4f3bf76f2a01af199cf38b24713a8 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 9 Dec 2024 20:56:12 +0000 Subject: [PATCH 175/285] add docstring --- .../pipelines/sequential/pipeline.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 7c91fd0b4..9e3f7fe70 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -24,6 +24,22 @@ def run_pipeline( dataloader: torch.utils.data.DataLoader, propagate_error: bool, ): + """ + Run a sequential data pipeline. First, the model is partitioned into subgraphs + according to `sequential_targets`. Then, data passes through each subgraph + sequentially. If `propagate_error` is enabled, then data is passed through each + subgraph twice, once to trigger calibration hooks, then a second time in order to + capture activations after quantization has occurred through the hooks. + + In order to reduce memory requirements + 1. Data is passed through each subgraph with batch size 1 + 2. The intermediate activations between each subgraph are offloaded onto the CPU + + This pipeline requires that the model be tracable with respect to data from the + data loader. This may be an issue for vision language models with vision datasets, + due to specialized input processing in the model. In the event that tracing fails, + a torch.fx.proxy.TraceError will be raised. + """ # trace subgraphs sample_input = next(iter(dataloader)) targets = infer_sequential_targets(model, sequential_targets, ignore) From f1f6c0f7d1c3ce3a55c526200910598814335441 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 9 Dec 2024 23:48:23 +0000 Subject: [PATCH 176/285] begin sequential pipeline testing --- examples/quantization_w4a16/llama3_example.py | 5 ++++- src/llmcompressor/pipelines/sequential/__init__.py | 1 - src/llmcompressor/pipelines/sequential/pipeline.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index c08165299..a294a9b39 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -1,3 +1,4 @@ +from accelerate import cpu_offload from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer @@ -10,8 +11,10 @@ model = AutoModelForCausalLM.from_pretrained( MODEL_ID, device_map="auto", + # device_map="cuda:0", torch_dtype="auto", ) +# cpu_offload(model) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) # Select calibration dataset. @@ -20,7 +23,7 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 512 +NUM_CALIBRATION_SAMPLES = 1 # 512 MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. diff --git a/src/llmcompressor/pipelines/sequential/__init__.py b/src/llmcompressor/pipelines/sequential/__init__.py index 2b0a117ce..fc60475ca 100644 --- a/src/llmcompressor/pipelines/sequential/__init__.py +++ b/src/llmcompressor/pipelines/sequential/__init__.py @@ -1,3 +1,2 @@ # flake8: noqa -from .helpers import * from .pipeline import run_pipeline diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 9e3f7fe70..73325a7e4 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -52,7 +52,8 @@ def run_pipeline( # prepare intermediates cache desc = "Preparing intermediates cache" batch_intermediates = [ - apply_pad_mask_to_batch(batch) for batch in tqdm.tqdm(dataloader, desc=desc) + apply_pad_mask_to_batch(batch) if "attention_mask" in batch else batch + for batch in tqdm.tqdm(dataloader, desc=desc) ] batch_outputs = [None for _ in range(len(dataloader))] From 3b0b49f7a9d5eededca5836e1ea4a9e3257ced8a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 10 Dec 2024 02:24:15 +0000 Subject: [PATCH 177/285] remove todos, add tests for sequential pipeline --- examples/quantization_w4a16/llama3_example.py | 3 --- .../quantization/gptq/utils/gptq_quantize.py | 1 - .../modifiers/utils/pytorch_helpers.py | 1 - .../pipelines/sequential/helpers.py | 20 +++++++++++-------- .../pipelines/sequential/pipeline.py | 2 -- src/llmcompressor/pytorch/tracing/llava.py | 2 +- 6 files changed, 13 insertions(+), 16 deletions(-) diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index a294a9b39..fe574293b 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -1,4 +1,3 @@ -from accelerate import cpu_offload from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer @@ -11,10 +10,8 @@ model = AutoModelForCausalLM.from_pretrained( MODEL_ID, device_map="auto", - # device_map="cuda:0", torch_dtype="auto", ) -# cpu_offload(model) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) # Select calibration dataset. diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index a6f3e7d48..fc4b56edc 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -154,7 +154,6 @@ def quantize_weight( W[:, dead] = 0 # compute inverse hessian in place to save memory - # TODO: check in place Hinv = _invert_hessian(H, percdamp) # See section 3.4 of https://arxiv.org/abs/2203.07259 diff --git a/src/llmcompressor/modifiers/utils/pytorch_helpers.py b/src/llmcompressor/modifiers/utils/pytorch_helpers.py index 96769119c..7de7dc58d 100644 --- a/src/llmcompressor/modifiers/utils/pytorch_helpers.py +++ b/src/llmcompressor/modifiers/utils/pytorch_helpers.py @@ -101,7 +101,6 @@ def run_calibration_forward( # model was stopped early, save last calculated output and # move on to next calibration sample intermediates.append((e.args, e.kwargs)) - pass # TODO: not ideal, figure out where we aren't freeing memory instead # currently without this we run OOM on the 2nd forward pass diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index 4c1d3c017..2331e30ba 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -56,7 +56,6 @@ def trace_subgraphs( with ( calibration_forward_context(model), HooksMixin.disable_hooks(), - # disable_hf_hook(model, recurse=True), ): graph = GraphModule( model, @@ -239,12 +238,18 @@ def partition_graph(model: Module, partitions: List[List[Node]]) -> List[Subgrap getattr_chain(model, node.target) for node in graph.find_nodes(op="call_module") ] - first_offloaded = next((m for m in modules if has_offloaded_params(m)), None) - input_device = ( - torch.device(first_offloaded._hf_hook.execution_device) - if first_offloaded is not None - else model.device - ) + if len(modules) > 0: + first_offloaded = next( + (m for m in modules if has_offloaded_params(m)), None + ) + input_device = ( + torch.device(first_offloaded._hf_hook.execution_device) + if first_offloaded is not None + else next(modules[0].parameters()).device + ) + + else: + input_device = model.device # save the subgraph for this partition graph.lint() @@ -264,7 +269,6 @@ def partition_graph(model: Module, partitions: List[List[Node]]) -> List[Subgrap def trace_consumed_names(subgraphs: List[Dict[str, Any]]): - # TODO: update consumed names as new partitions are appended # populate consumed_names according to when inputs are last used # in order to vacate the `intermediates` cache and save memory all_input_names = set().union(*(subgraph.input_names for subgraph in subgraphs)) diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 73325a7e4..1325712de 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -76,8 +76,6 @@ def run_pipeline( input_name: intermediates[input_name] for input_name in subgraph.input_names } - # graph_module = torch.fx.GraphModule(model, subgraph.graph) - # breakpoint() inputs = tensors_to_device(inputs, subgraph.input_device) forward_function(model, **inputs) diff --git a/src/llmcompressor/pytorch/tracing/llava.py b/src/llmcompressor/pytorch/tracing/llava.py index 89da7cdea..6ee76f299 100644 --- a/src/llmcompressor/pytorch/tracing/llava.py +++ b/src/llmcompressor/pytorch/tracing/llava.py @@ -115,7 +115,7 @@ def forward( position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] - # TODO: @raushan retain only the new behavior after v4.47 + # @raushan retain only the new behavior after v4.47 elif image_features is not None: n_image_tokens = (input_ids == self.config.image_token_index).sum().item() n_image_features = image_features.shape[0] * image_features.shape[1] From 2c035b311808b3764248fbe3f3eef2db94e0e4af Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 10 Dec 2024 02:25:52 +0000 Subject: [PATCH 178/285] move function placement --- .../pipelines/sequential/helpers.py | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index 2331e30ba..ab072ddcb 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -128,24 +128,6 @@ def get_target_nodes(graph: GraphModule, targets: Set[Module]) -> Set[Node]: ) -def check_assumption(graph: Graph) -> bool: - for node in graph.nodes: - for user in node.users: - if node not in user.all_input_nodes: - return False - - for input_node in node.all_input_nodes: - if node not in input_node.users: - return False - - if len(node.users) != len(set(node.users)) or len(node.all_input_nodes) != len( - set(node.all_input_nodes) - ): - return False - - return True - - def topological_partition(graph: GraphModule, targets: Set[Module]) -> List[List[Node]]: assert check_assumption(graph.graph) target_nodes = get_target_nodes(graph, targets) @@ -279,3 +261,21 @@ def trace_consumed_names(subgraphs: List[Dict[str, Any]]): break else: assert False + + +def check_assumption(graph: Graph) -> bool: + for node in graph.nodes: + for user in node.users: + if node not in user.all_input_nodes: + return False + + for input_node in node.all_input_nodes: + if node not in input_node.users: + return False + + if len(node.users) != len(set(node.users)) or len(node.all_input_nodes) != len( + set(node.all_input_nodes) + ): + return False + + return True From b93868d1f4fad1b125582ac297783b210bb924db Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 10 Dec 2024 02:31:37 +0000 Subject: [PATCH 179/285] slight partition algorithm change --- src/llmcompressor/pipelines/sequential/helpers.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index ab072ddcb..4fd4db4ff 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -148,15 +148,14 @@ def topological_partition(graph: GraphModule, targets: Set[Module]) -> List[List while len(queue) > 0: node = queue.popleft() - # TODO: test swapping with below + # assign to partition + partitions[partition_index].append(node) + # guarantee targets are assigned to disjoint partitions if node in target_nodes: partition_index += 1 partitions.append([]) - # assign to partition - partitions[partition_index].append(node) - # recurse on last indegree only in order to guarantee that # the node is assigned to maximal partition for user in node.users: From 146e4be0669516ac7a4d5dce742b52bcbbc18fd3 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 10 Dec 2024 02:32:34 +0000 Subject: [PATCH 180/285] revert llama3 example --- examples/quantization_w4a16/llama3_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index fe574293b..c08165299 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -20,7 +20,7 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 1 # 512 +NUM_CALIBRATION_SAMPLES = 512 MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. From ccb007fecd84af8346e8cdd9ed86a9e66226f6e9 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 10 Dec 2024 14:28:40 -0500 Subject: [PATCH 181/285] remove test, fix default in order to fix tests --- .../modifiers/quantization/gptq/base.py | 2 +- .../gptq/utils/test_gptq_wrapper.py | 41 ------------------- 2 files changed, 1 insertion(+), 42 deletions(-) delete mode 100644 tests/llmcompressor/modifiers/quantization/gptq/utils/test_gptq_wrapper.py diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 8c66c272b..0c3c96d34 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -120,7 +120,7 @@ class GPTQModifier(Modifier, HooksMixin): disable_quantization_observer_epoch: Optional[float] = None # private variables - _quantization_modifier: Optional[QuantizationModifier] = PrivateAttr() + _quantization_modifier: Optional[QuantizationModifier] = PrivateAttr(default=None) _hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict) _num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict) _update_size: Optional[int] = PrivateAttr(default=None) diff --git a/tests/llmcompressor/modifiers/quantization/gptq/utils/test_gptq_wrapper.py b/tests/llmcompressor/modifiers/quantization/gptq/utils/test_gptq_wrapper.py deleted file mode 100644 index 203d1fe03..000000000 --- a/tests/llmcompressor/modifiers/quantization/gptq/utils/test_gptq_wrapper.py +++ /dev/null @@ -1,41 +0,0 @@ -from collections import OrderedDict - -import torch -from compressed_tensors.quantization.lifecycle.apply import apply_quantization_config -from compressed_tensors.quantization.quant_config import QuantizationConfig -from compressed_tensors.quantization.quant_scheme import preset_name_to_scheme -from loguru import logger - -from llmcompressor.modifiers.quantization.gptq.utils.gptq_wrapper import GPTQWrapper - - -def test_ignore(): - model = torch.nn.Sequential( - OrderedDict( - [ - ("first_layer", torch.nn.Linear(2, 3)), - ("second_layer", torch.nn.Linear(3, 5)), - ] - ) - ) - - config = QuantizationConfig( - config_groups={"group_0": preset_name_to_scheme("W8A8", targets=["Linear"])}, - ignore=["first_layer"], - ) - apply_quantization_config(model, config) - - messages = [] - logger.add(lambda m: messages.append(m)) - - with torch.no_grad(): - first_compressor = GPTQWrapper("first_layer", model.first_layer) - first_compressor.add_batch(torch.ones(2), None) - first_compressor.compress() - - second_compressor = GPTQWrapper("second_layer", model.second_layer) - second_compressor.add_batch(torch.ones(3), None) - second_compressor.compress() - - assert sum("Skipping unquantized layer first_layer" in m for m in messages) == 1 - assert sum("Skipping unquantized layer second_layer" in m for m in messages) == 0 From e1055b0fd914587e9b8e7a0393f634f7216d0c80 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 11 Dec 2024 16:18:00 -0500 Subject: [PATCH 182/285] bump memory requirements --- .../modifiers/quantization/gptq/base.py | 23 ++--- .../quantization/gptq/utils/gptq_quantize.py | 94 +++++-------------- .../transformers/compression/helpers.py | 5 +- 3 files changed, 39 insertions(+), 83 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 0c3c96d34..bfe1f9715 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -306,25 +306,22 @@ def compress_module( CompressionLogger(module) as comp_logger, ): loss, quantized_weight, scale, zero_point, g_idx = quantize_weight( - weight=module.weight.data, + module=module, quant_args=quant_args, - hessian=self._hessians[module], + hessians_dict=self._hessians, blocksize=self.block_size, percdamp=self.dampening_frac, - module_class=type(module), ) + comp_logger.set_loss(loss) - module.weight += quantized_weight - module.weight # Future: FSDP - update_offload_parameter(module, "weight", module.weight.data) - update_offload_parameter(module, "weight_scale", scale) - update_offload_parameter(module, "weight_zero_point", zero_point) - if g_idx is not None: - update_offload_parameter(module, "weight_g_idx", g_idx) - - del self._hessians[module] - del self._num_samples[module] + update_offload_parameter(module, "weight", quantized_weight) + update_offload_parameter(module, "weight_scale", scale) + update_offload_parameter(module, "weight_zero_point", zero_point) + if g_idx is not None: + update_offload_parameter(module, "weight_g_idx", g_idx) - comp_logger.set_loss(loss) + # self._hessians[module] already deleted by quantize_weight + del self._num_samples[module] @contextlib.contextmanager def _maybe_onload_hessian(self, module: torch.nn.Module): diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index fc4b56edc..d35fb9748 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -1,6 +1,6 @@ import math from copy import copy -from typing import Optional, Tuple, Type, Union +from typing import Dict, Optional, Tuple, Type, Union import torch import transformers @@ -58,31 +58,30 @@ def accumulate_hessian( def quantize_weight( - weight: torch.Tensor, + module: torch.nn.Module, quant_args: QuantizationArgs, - hessian: Optional[torch.Tensor] = None, - inp: Optional[torch.Tensor] = None, + hessians_dict: Dict[torch.nn.Module, torch.Tensor], blocksize: int = 128, percdamp: float = 0.01, - module_class: Type[torch.nn.Module] = torch.nn.Linear, ) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]: """ Quantize a module weight according to the GPTQ algorithm - :param weight: weight being quantized + :param module: module with weight being quantized :param quant_args: quantization arguments used to find quantization parameters - :param hessian: preaccumulated hessian for quantization - :param inp: module inputs used to calculate hessian. Incompatible with `hessian` arg + :param hessian_dict: dictionary containing preaccumulated hessian for quantization :param blocksize: chunk size of quantization updates :param percdamp: dampening factor on hessian diagonal - :param module_class: class of module, likely torch.nn.Linear :return: loss, quantized_weight, scale, zero_point, g_idx """ strategy = quant_args.strategy actorder = quant_args.actorder - final_shape = weight.shape - final_dtype = weight.dtype - W = weight.data.clone() + final_shape = module.weight.shape + final_dtype = module.weight.dtype + module_class = type(module) + W = module.weight.clone() + H = hessians_dict[module] # unfortunately python does not have a `move` keyword + del hessians_dict[module] # so we have to delete the original reference manually # create observer for calculating quantization parameters observer = Observer.load_from_registry( @@ -100,16 +99,6 @@ def quantize_weight( num_rows = W.shape[0] num_columns = W.shape[1] - # compute hessian - if inp is not None: - if hessian is not None: - raise ValueError("Must pass either inp or hessian, but not both") - H = _compute_hessian(inp, module_class, device=weight.device) - elif hessian is not None: - H = hessian - else: - raise ValueError("Must pass either inp or hessian") - if strategy == QuantizationStrategy.GROUP: # mapping from column index to group index g_idx = ( @@ -146,7 +135,7 @@ def quantize_weight( else None ) - losses = torch.zeros(num_rows, device=weight.device) + losses = torch.zeros(num_rows, device=module.weight.device) # mask dead hessian values dead = torch.diag(H) == 0 @@ -154,7 +143,20 @@ def quantize_weight( W[:, dead] = 0 # compute inverse hessian in place to save memory - Hinv = _invert_hessian(H, percdamp) + try: + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(H.shape[0], device=H.device) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + Hinv = H + except torch._C._LinAlgError: + raise ValueError( + "Failed to invert hessian due to numerical instability. Consider " + "increasing GPTQModifier.dampening_frac, increasing the number " + "of calibration samples, or shuffling the calibration dataset" + ) # See section 3.4 of https://arxiv.org/abs/2203.07259 for i1 in range(0, num_columns, blocksize): @@ -265,50 +267,6 @@ def quantize_weight( ) -def _invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: - """ - Performs in-place inversion of the hessian in order to save memory - - :param H: hessian being inverted - :param percdamp: dampening factor on hessian diagonal - :return: inverted hessian - """ - damp = percdamp * torch.mean(torch.diag(H)) - diag = torch.arange(H.shape[0], device=H.device) - H[diag, diag] += damp - H = torch.linalg.cholesky(H) - H = torch.cholesky_inverse(H) - H = torch.linalg.cholesky(H, upper=True) - return H - - -def _compute_hessian( - inp: torch.Tensor, module_class: Type[torch.nn.Module], device -) -> torch.Tensor: - """ - Calculate the hessian with respect to the module inputs - - :param inp: module inputs - :param module_class: class of module, likely torch.nn.Linear - :return: hessian w.r.t. module inputs - """ - inp = inp.to(device=device) - if len(inp.shape) == 2: - inp = inp.unsqueeze(0) - - nsamples = inp.shape[0] # note this is the number of dataset samples, not - # multiplied by the sequence length - - if module_class in (torch.nn.Linear, transformers.Conv1D): - if len(inp.shape) == 3: - inp = inp.reshape((-1, inp.shape[-1])) - inp = inp.t() - - inp = inp.to(dtype=GPTQ_PRECISION) - inp = math.sqrt(2 / nsamples) * inp - return inp.matmul(inp.t()) - - def _apply_activation_ordering( W: torch.Tensor, H: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: diff --git a/src/llmcompressor/transformers/compression/helpers.py b/src/llmcompressor/transformers/compression/helpers.py index 7c839c5a7..ee764b9f8 100644 --- a/src/llmcompressor/transformers/compression/helpers.py +++ b/src/llmcompressor/transformers/compression/helpers.py @@ -137,7 +137,8 @@ def hessian_memory_requirements(model: torch.nn.Module) -> int: max_total_hessian_elems = max(total_hessian_elems.values()) overall_max_column_size = max(max_column_size.values()) bytes_per_weight = 32 // 8 # hessians are float32 - inverse_reserved = overall_max_column_size * overall_max_column_size + # allocate enough space for out of place operations + inverse_reserved = overall_max_column_size * overall_max_column_size * 2 return (max_total_hessian_elems + inverse_reserved) * bytes_per_weight @@ -236,7 +237,7 @@ def calculate_offload_device_map( reserved_memory = 0 if reserve_for_hessians: - reserved_memory = hessian_memory_requirements(dummy_model) + reserved_memory = hessian_memory_requirements(dummy_model) * 2 reserved_memory += quantization_memory_requirement(dummy_model) memory_limits = { From 70421eddcdb1d8b8b2e5d3dd07a2eaa2321109ec Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 12 Dec 2024 13:35:19 -0500 Subject: [PATCH 183/285] fix memory and offloading issues --- .../pipelines/sequential/helpers.py | 39 ++++--------- .../pipelines/sequential/pipeline.py | 57 +++++++------------ .../transformers/compression/helpers.py | 3 +- 3 files changed, 32 insertions(+), 67 deletions(-) diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index 4fd4db4ff..b35f309e5 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -3,7 +3,6 @@ from dataclasses import dataclass from typing import Any, Dict, List, Set -import torch from compressed_tensors import has_offloaded_params from compressed_tensors.quantization import find_name_or_class_matches from torch.fx import Graph, GraphModule, Node @@ -11,15 +10,19 @@ from transformers.utils.fx import HFTracer from llmcompressor.modifiers.utils.hooks import HooksMixin -from llmcompressor.utils.helpers import calibration_forward_context, getattr_chain +from llmcompressor.utils.helpers import calibration_forward_context @dataclass class Subgraph: graph: Graph - input_names: List[str] - consumed_names: List[str] - input_device: torch.device + input_names: Set[str] + consumed_names: Set[str] + + def compile_forward(self): + code = self.graph.python_code("self") + exec(code.src, code.globals) + return code.globals.get("forward") __all__ = ["infer_sequential_targets", "trace_subgraphs"] @@ -213,34 +216,14 @@ def partition_graph(model: Module, partitions: List[List[Node]]) -> List[Subgrap } graph.output(output_dict) - # find input device for subgraph - # note: find_nodes is topologically sorted - modules = [ - getattr_chain(model, node.target) - for node in graph.find_nodes(op="call_module") - ] - if len(modules) > 0: - first_offloaded = next( - (m for m in modules if has_offloaded_params(m)), None - ) - input_device = ( - torch.device(first_offloaded._hf_hook.execution_device) - if first_offloaded is not None - else next(modules[0].parameters()).device - ) - - else: - input_device = model.device - # save the subgraph for this partition graph.lint() - input_names = [node.name for node in graph.nodes if node.op == "placeholder"] + input_names = set(node.name for node in graph.nodes if node.op == "placeholder") subgraphs.append( Subgraph( graph=graph, input_names=input_names, - consumed_names=[], # populated later - input_device=input_device, + consumed_names=set(), # populated later ) ) @@ -256,7 +239,7 @@ def trace_consumed_names(subgraphs: List[Dict[str, Any]]): for input_name in all_input_names: for subgraph in reversed(subgraphs): if input_name in subgraph.input_names: - subgraph.consumed_names.append(input_name) + subgraph.consumed_names.add(input_name) break else: assert False diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 1325712de..ec34b2273 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -4,14 +4,14 @@ import torch import torch.utils.data.dataloader import tqdm +from compressed_tensors.utils import get_execution_device from llmcompressor.modifiers.utils.hooks import HooksMixin -from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch +from llmcompressor.pipelines.sequential.cache import IntermediatesCache from llmcompressor.pipelines.sequential.helpers import ( infer_sequential_targets, trace_subgraphs, ) -from llmcompressor.pytorch.utils.helpers import tensors_to_device from llmcompressor.utils.helpers import calibration_forward_context __all__ = ["run_pipeline"] @@ -33,7 +33,7 @@ def run_pipeline( In order to reduce memory requirements 1. Data is passed through each subgraph with batch size 1 - 2. The intermediate activations between each subgraph are offloaded onto the CPU + 2. Intermediate activations between each subgraph are offloaded onto the CPU This pipeline requires that the model be tracable with respect to data from the data loader. This may be an issue for vision language models with vision datasets, @@ -50,55 +50,38 @@ def run_pipeline( with calibration_forward_context(model): # prepare intermediates cache - desc = "Preparing intermediates cache" - batch_intermediates = [ - apply_pad_mask_to_batch(batch) if "attention_mask" in batch else batch - for batch in tqdm.tqdm(dataloader, desc=desc) - ] - batch_outputs = [None for _ in range(len(dataloader))] + model_device = get_execution_device(model) + intermediates = IntermediatesCache.from_dataloader(dataloader, model_device) + model_outputs = [dict() for _ in range(len(dataloader))] num_subgraphs = len(subgraphs) - for index, subgraph in enumerate(subgraphs): + for subgraph_index, subgraph in enumerate(subgraphs): # prepare tqdm description texts - uncomp_desc = f"({index + 1}/{num_subgraphs}): Calibrating" - comp_desc = f"({index + 1}/{num_subgraphs}): Propagate" + calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating" + prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagate" # compile subgraph forward function - code = subgraph.graph.python_code("self") - exec(code.src, code.globals) - forward_function = code.globals.get("forward") + forward_function = subgraph.compile_forward() if propagate_error: # do an preliminary pass to trigger modifier hooks - for batch_index in tqdm.tqdm(range(len(dataloader)), desc=uncomp_desc): - intermediates = batch_intermediates[batch_index] - inputs = { - input_name: intermediates[input_name] - for input_name in subgraph.input_names - } - inputs = tensors_to_device(inputs, subgraph.input_device) + for batch_index in tqdm.tqdm(range(len(dataloader)), desc=calib_desc): + inputs = intermediates.fetch(batch_index, subgraph.input_names) forward_function(model, **inputs) + del inputs # if using propagate_error, then this pass does not trigger modifier hooks # and is only used for capturing intermediates # otherwise, this pass triggers modifier hooks and captures intermediates with HooksMixin.disable_hooks() if propagate_error else nullcontext(): - desc = comp_desc if propagate_error else uncomp_desc + desc = prop_desc if propagate_error else calib_desc for batch_index in tqdm.tqdm(range(len(dataloader)), desc=desc): - intermediates = batch_intermediates[batch_index] - - inputs = { - input_name: intermediates[input_name] - for input_name in subgraph.input_names - } - inputs = tensors_to_device(inputs, subgraph.input_device) + inputs = intermediates.fetch(batch_index, subgraph.input_names) subgraph_output = forward_function(model, **inputs) - subgraph_output = tensors_to_device(subgraph_output, "cpu") - - for consumed_name in subgraph.consumed_names: - del intermediates[consumed_name] + del inputs - if index < len(subgraphs) - 1: - intermediates.update(subgraph_output) + if subgraph_index < len(subgraphs) - 1: + intermediates.update(batch_index, subgraph_output) + intermediates.delete(batch_index, subgraph.consumed_names) else: - batch_outputs[batch_index] = subgraph_output + model_outputs[batch_index] = subgraph_output diff --git a/src/llmcompressor/transformers/compression/helpers.py b/src/llmcompressor/transformers/compression/helpers.py index ee764b9f8..8dd6a0cb6 100644 --- a/src/llmcompressor/transformers/compression/helpers.py +++ b/src/llmcompressor/transformers/compression/helpers.py @@ -137,8 +137,7 @@ def hessian_memory_requirements(model: torch.nn.Module) -> int: max_total_hessian_elems = max(total_hessian_elems.values()) overall_max_column_size = max(max_column_size.values()) bytes_per_weight = 32 // 8 # hessians are float32 - # allocate enough space for out of place operations - inverse_reserved = overall_max_column_size * overall_max_column_size * 2 + inverse_reserved = overall_max_column_size * overall_max_column_size return (max_total_hessian_elems + inverse_reserved) * bytes_per_weight From b102bf51c411a5884d3a45a8572c1705d2e39cd7 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 12 Dec 2024 16:36:48 -0500 Subject: [PATCH 184/285] add missing cache file --- .../pipelines/sequential/cache.py | 94 +++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 src/llmcompressor/pipelines/sequential/cache.py diff --git a/src/llmcompressor/pipelines/sequential/cache.py b/src/llmcompressor/pipelines/sequential/cache.py new file mode 100644 index 000000000..7d2defb74 --- /dev/null +++ b/src/llmcompressor/pipelines/sequential/cache.py @@ -0,0 +1,94 @@ +import warnings +from dataclasses import dataclass +from typing import Any, Dict, List, Union + +import torch +import tqdm + +from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch + + +@dataclass +class IntermediateValue: + value: Any + device: Union[torch.device, None] + + +class IntermediatesCache: + batch_intermediates: List[Dict[str, IntermediateValue]] + offload_device: torch.device + + def __init__( + self, + batch_intermediates: List[Dict[str, IntermediateValue]], + offload_device: torch.device, + ): + self.batch_intermediates = batch_intermediates + self.offload_device = offload_device + + @classmethod + def from_dataloader( + cls, + dataloader: torch.utils.data.DataLoader, + model_device: torch.device, + mask_padding: bool = True, + offload_device: torch.device = "cpu", + ): + batch_intermediates = [] + for batch in tqdm.tqdm(dataloader, desc="Preparing intermediates cache"): + if mask_padding and "attention_mask" in batch: + batch = apply_pad_mask_to_batch(batch) + batch = { + key: IntermediateValue(value=value, device=model_device) + for key, value in batch.items() + } + batch_intermediates.append(batch) + + return cls(batch_intermediates, offload_device) + + def fetch(self, batch_index: int, input_names: List[str]) -> Dict[str, Any]: + intermediates = self.batch_intermediates[batch_index] + + return { + key: self._onload_value(subgraph_input) + for key, subgraph_input in intermediates.items() + if key in input_names + } + + def update(self, batch_index: int, outputs: Dict[str, Any]): + # assume that all model intermediates are tensors + assert (isinstance(value, torch.Tensor) for value in outputs.values()) + + intermediates = { + key: self._offload_value(value) for key, value in outputs.items() + } + + self.batch_intermediates[batch_index].update(intermediates) + + def delete(self, batch_index: int, consumed_names: List[str]): + intermediates = self.batch_intermediates[batch_index] + for name in consumed_names: + del intermediates[name] + + def _onload_value(self, intermediate: IntermediateValue) -> Any: + value = intermediate.value + device = intermediate.device + + if device is not None: + if isinstance(value, torch.Tensor): + return value.to(device=device) + else: + raise NotImplementedError("Intermediates") + + else: + return value + + def _offload_value(self, value: Any) -> IntermediateValue: + if isinstance(value, torch.Tensor): + return IntermediateValue( + value=value.to(device=self.offload_device), device=value.device + ) + + else: + warnings.warn(f"Offloading not implemented for type {type(value)}.") + return IntermediateValue(value=value, device=None) From 229d3ae532aacbe4f5aa39d659b8e50a86728379 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 12 Dec 2024 16:48:54 -0500 Subject: [PATCH 185/285] make mllama tracable --- examples/multimodal_vision/mllama.py | 9 +- examples/multimodal_vision/pixtral.py | 6 +- examples/multimodal_vision/pixtral_large.py | 6 + .../modifiers/quantization/gptq/base.py | 6 +- .../pipelines/sequential/helpers.py | 4 +- .../pipelines/sequential/pipeline.py | 6 +- .../finetune/data/data_helpers.py | 4 + .../transformers/tracing/__init__.py | 7 + .../tracing/llava.py | 89 +- .../transformers/tracing/mllama.py | 2567 +++++++++++++++++ 10 files changed, 2673 insertions(+), 31 deletions(-) create mode 100644 examples/multimodal_vision/pixtral_large.py create mode 100644 src/llmcompressor/transformers/tracing/__init__.py rename src/llmcompressor/{pytorch => transformers}/tracing/llava.py (72%) create mode 100644 src/llmcompressor/transformers/tracing/mllama.py diff --git a/examples/multimodal_vision/mllama.py b/examples/multimodal_vision/mllama.py index 3d1ba24af..02fca32d2 100644 --- a/examples/multimodal_vision/mllama.py +++ b/examples/multimodal_vision/mllama.py @@ -1,14 +1,16 @@ import os import torch -from transformers import AutoProcessor, MllamaForConditionalGeneration +from transformers import AutoProcessor from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.pytorch.data_collator import DataCollator from llmcompressor.transformers import oneshot +from llmcompressor.transformers.tracing import TracableMllamaForConditionalGeneration # Load model. model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" -model = MllamaForConditionalGeneration.from_pretrained( +model = TracableMllamaForConditionalGeneration.from_pretrained( model_id, device_map="auto", torch_dtype="auto" ) processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) @@ -16,7 +18,7 @@ # Oneshot arguments DATASET_ID = "flickr30k" DATASET_SPLIT = "test[:512]" -NUM_CALIBRATION_SAMPLES = 1 +NUM_CALIBRATION_SAMPLES = 512 MAX_SEQUENCE_LENGTH = 2048 @@ -58,6 +60,7 @@ def data_collator(batch): trust_remote_code_model=True, output_dir=save_path, data_collator=data_collator, + # data_collator=DataCollator(), ) processor.save_pretrained(save_path) diff --git a/examples/multimodal_vision/pixtral.py b/examples/multimodal_vision/pixtral.py index dee690480..600a146c4 100644 --- a/examples/multimodal_vision/pixtral.py +++ b/examples/multimodal_vision/pixtral.py @@ -4,8 +4,9 @@ from transformers import AutoProcessor from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.pytorch.tracing.llava import TracableLlavaForConditionalGeneration +from llmcompressor.pytorch.data_collator import DataCollator from llmcompressor.transformers import oneshot +from llmcompressor.transformers.tracing import TracableLlavaForConditionalGeneration # Load model. model_id = "mgoin/pixtral-12b" @@ -56,7 +57,8 @@ def data_collator(batch): num_calibration_samples=NUM_CALIBRATION_SAMPLES, trust_remote_code_model=True, output_dir=save_path, - data_collator=data_collator, + # data_collator=data_collator, + data_collator=DataCollator(), ) model.save_pretrained(save_path) diff --git a/examples/multimodal_vision/pixtral_large.py b/examples/multimodal_vision/pixtral_large.py new file mode 100644 index 000000000..ebef5047d --- /dev/null +++ b/examples/multimodal_vision/pixtral_large.py @@ -0,0 +1,6 @@ +from transformers import AutoProcessor + +processor = AutoProcessor.from_pretrained( + "mistral-community/Pixtral-Large-Instruct-2411" +) +processor = AutoProcessor.from_pretrained("mgoin/pixtral-12b") diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index bfe1f9715..1cfef2feb 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -27,6 +27,7 @@ from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.pipelines.basic import run_pipeline as run_basic from llmcompressor.pipelines.sequential import run_pipeline as run_sequential +from llmcompressor.transformers import tracing from llmcompressor.utils.metric_logging import CompressionLogger from llmcompressor.utils.pytorch.module import ( get_layers, @@ -220,7 +221,10 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self._update_size = len(state.data.calib) # infer pipeline - if "pixel_values" not in state.data.calib.dataset.column_names: + if ( + state.model.__class__.__name__ in tracing.__all__ + or "pixel_values" not in state.data.calib.dataset.column_names + ): try: run_sequential( state.model, diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index b35f309e5..b0136a53b 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -85,7 +85,7 @@ def trace_subgraphs( return subgraphs -def get_tracer(model: Module, targets: List[Module]) -> HFTracer: +def get_tracer(model: Module, sequential_targets: List[Module]) -> HFTracer: offloaded_modules = set( module for module in model.modules() if has_offloaded_params(module) ) @@ -94,7 +94,7 @@ class PiecewiseTracer(HFTracer): # Treat as leaf, skip tracing inside this module def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool: return ( - module in targets + module in sequential_targets or module in offloaded_modules or super().is_leaf_module(module, module_qualified_name) ) diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index ec34b2273..10a40cfc2 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -77,11 +77,11 @@ def run_pipeline( desc = prop_desc if propagate_error else calib_desc for batch_index in tqdm.tqdm(range(len(dataloader)), desc=desc): inputs = intermediates.fetch(batch_index, subgraph.input_names) - subgraph_output = forward_function(model, **inputs) + output = forward_function(model, **inputs) del inputs if subgraph_index < len(subgraphs) - 1: - intermediates.update(batch_index, subgraph_output) + intermediates.update(batch_index, output) intermediates.delete(batch_index, subgraph.consumed_names) else: - model_outputs[batch_index] = subgraph_output + model_outputs[batch_index] = output diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index ee3a47f88..797c1fb40 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -107,6 +107,10 @@ def format_calibration_data( if accelerator: calib_dataloader = accelerator.prepare(calib_dataloader) + # sample = next(iter(calib_dataloader)) + # print({k: [torch.tensor(s).shape for s in sample[k]] for k in sample}) + # breakpoint() + return calib_dataloader diff --git a/src/llmcompressor/transformers/tracing/__init__.py b/src/llmcompressor/transformers/tracing/__init__.py new file mode 100644 index 000000000..8ca3d9777 --- /dev/null +++ b/src/llmcompressor/transformers/tracing/__init__.py @@ -0,0 +1,7 @@ +from .llava import TracableLlavaForConditionalGeneration +from .mllama import TracableMllamaForConditionalGeneration + +__all__ = [ + "TracableLlavaForConditionalGeneration", + "TracableMllamaForConditionalGeneration", +] diff --git a/src/llmcompressor/pytorch/tracing/llava.py b/src/llmcompressor/transformers/tracing/llava.py similarity index 72% rename from src/llmcompressor/pytorch/tracing/llava.py rename to src/llmcompressor/transformers/tracing/llava.py index 6ee76f299..85a418e79 100644 --- a/src/llmcompressor/pytorch/tracing/llava.py +++ b/src/llmcompressor/transformers/tracing/llava.py @@ -1,4 +1,20 @@ # flake8: noqa +# coding=utf-8 +# Copyright 2023 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Llava model.""" + from functools import wraps from typing import List, Optional, Tuple, Union @@ -27,13 +43,23 @@ def forward( cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, ) -> Union[Tuple, LlavaCausalLMOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict vision_feature_layer = ( - vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + vision_feature_layer + if vision_feature_layer is not None + else self.config.vision_feature_layer ) vision_feature_select_strategy = ( vision_feature_select_strategy @@ -42,7 +68,9 @@ def forward( ) if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) if pixel_values is not None and inputs_embeds is not None: raise ValueError( @@ -56,7 +84,7 @@ def forward( # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing # not very reliable, but we don't expect one to actually pass 500+ images for one prompt # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True - + # NOT TRACABLE, instead always use legacy_processing = False # legacy_processing = ( # (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length @@ -79,17 +107,23 @@ def forward( ) # prefill stage vs decoding stage (legacy behavior copied) if input_ids.shape[1] != 1: - inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( - image_features, inputs_embeds, input_ids, attention_mask, labels + inputs_embeds, attention_mask, labels, position_ids = ( + self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, attention_mask, labels + ) + ) + cache_position = torch.arange( + attention_mask.shape[1], device=attention_mask.device ) - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) else: # Retrieve the first layer to inspect the logits and mask out the hidden states # that are set to 0 first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + batch_index, non_attended_tokens = torch.where( + first_layer_past_key_value.float().sum(-2) == 0 + ) # Get the target length target_length = input_ids.shape[1] @@ -111,9 +145,13 @@ def forward( # Zero-out the places where we don't need to attend extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) + attention_mask = torch.cat( + (extended_attention_mask, attention_mask[:, -target_length:]), dim=1 + ) position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] + cache_position = torch.arange( + attention_mask.shape[1], device=attention_mask.device + )[-target_length:] # @raushan retain only the new behavior after v4.47 elif image_features is not None: @@ -121,7 +159,7 @@ def forward( n_image_features = image_features.shape[0] * image_features.shape[1] # NOT TRACABLE, instead always use n_image_tokens != n_image_features = False - #if n_image_tokens != n_image_features: + # if n_image_tokens != n_image_features: if False: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" @@ -132,8 +170,12 @@ def forward( .expand_as(inputs_embeds) .to(inputs_embeds.device) ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + image_features = image_features.to( + inputs_embeds.device, inputs_embeds.dtype + ) + inputs_embeds = inputs_embeds.masked_scatter( + special_image_mask, image_features + ) outputs = self.language_model( attention_mask=attention_mask, @@ -156,16 +198,23 @@ def forward( if attention_mask is not None: # we use the input attention mask to shift the logits and labels, because it is 2D. # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) - shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() - shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to( + logits.device + ) + shift_logits = logits[..., :-1, :][ + shift_attention_mask.to(logits.device) != 0 + ].contiguous() + shift_labels = labels[..., 1:][ + shift_attention_mask.to(labels.device) != 0 + ].contiguous() else: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = torch.nn.CrossEntropyLoss() loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1).to(shift_logits.device), ) if not return_dict: @@ -179,4 +228,4 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=image_features if pixel_values is not None else None, - ) \ No newline at end of file + ) diff --git a/src/llmcompressor/transformers/tracing/mllama.py b/src/llmcompressor/transformers/tracing/mllama.py new file mode 100644 index 000000000..301517618 --- /dev/null +++ b/src/llmcompressor/transformers/tracing/mllama.py @@ -0,0 +1,2567 @@ +# flake8: noqa +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Mllama model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.fx import wrap +from transformers import PreTrainedModel +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.models.mllama.configuration_mllama import ( + MllamaConfig, + MllamaTextConfig, + MllamaVisionConfig, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) + +logger = logging.get_logger(__name__) + + +@wrap # NOT TRACABLE, wrap this function +def _prepare_cross_attention_mask( + cross_attention_mask: torch.Tensor, + num_vision_tokens: int, + dtype: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + # reshape so it can be used by attn module + batch_size, text_total_length, *_ = cross_attention_mask.shape + cross_attention_mask = cross_attention_mask.repeat_interleave( + num_vision_tokens, dim=3 + ) + cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1) + cross_attention_mask = cross_attention_mask.unsqueeze(1) + + # invert the mask + inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min + ) + + # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's + # last dimension contains negative infinity values, otherwise it's 1 + negative_inf_value = torch.finfo(dtype).min + full_text_row_masked_out_mask = ( + (cross_attention_mask != negative_inf_value) + .any(dim=-1) + .type_as(cross_attention_mask)[..., None] + ) + cross_attention_mask *= full_text_row_masked_out_mask + + return cross_attention_mask, full_text_row_masked_out_mask + + +def _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask: torch.Tensor, + num_patches: int, + target_length: int, + dtype: torch.dtype, +) -> torch.Tensor: + # Expand aspect ratio mask to target_length + batch_size, max_num_tiles = aspect_ratio_mask.shape + attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype) + attention_mask = attention_mask.repeat(1, 1, target_length, 1) + + # Mask padding patches + pad_patches = target_length - num_patches + attention_mask[:, :, -pad_patches:] = 0 + + # Invert the mask (0 -> 1, 1 -> 0) + attention_mask = 1 - attention_mask + + # Reshape to 2D and create 4D attention mask + # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length) + attention_mask = attention_mask.reshape( + batch_size, max_num_tiles * target_length, 1 + ) + attention_mask = ( + attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min + ) + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + +class MllamaPrecomputedAspectRatioEmbedding(nn.Module): + def __init__(self, config: MllamaVisionConfig, is_gated: bool = True): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.is_gated = is_gated + + self.embedding = nn.Embedding( + self.max_aspect_ratio_id + 1, self.max_num_tiles * self.hidden_size + ) + if is_gated: + self.gate = nn.Parameter(torch.zeros(1)) + + def forward( + self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor + ) -> torch.Tensor: + embeddings = self.embedding(aspect_ratio_ids) + embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size) + + if self.is_gated: + embeddings = embeddings * self.gate.tanh() + + hidden_state = hidden_state + embeddings + return hidden_state + + +class MllamaPrecomputedPositionEmbedding(nn.Module): + def __init__(self, config: MllamaVisionConfig): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.num_patches = (config.image_size // config.patch_size) ** 2 + 1 + self.hidden_size = config.hidden_size + self.scale = config.hidden_size**-0.5 + + self.gate = nn.Parameter(torch.zeros(1)) + + # position embedding + position_embedding = torch.randn(self.num_patches, self.hidden_size) + self.embedding = nn.Parameter(self.scale * position_embedding) + + # tile position embedding + self.tile_embedding = nn.Embedding( + self.max_aspect_ratio_id + 1, + self.max_num_tiles * self.num_patches * self.hidden_size, + ) + + def forward( + self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor + ) -> torch.Tensor: + # position embeddings + gated_position_embedding = (1 - self.gate.tanh()) * self.embedding + hidden_state = hidden_state + gated_position_embedding.view( + 1, 1, self.num_patches, self.hidden_size + ) + + # precomputed tile position embeddings + tile_position_embedding = self.tile_embedding(aspect_ratio_ids) + batch_size = hidden_state.shape[0] + tile_position_embedding = tile_position_embedding.reshape( + batch_size, self.max_num_tiles, self.num_patches, self.hidden_size + ) + gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding + hidden_state = hidden_state + gated_tile_position_embedding + + return hidden_state + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision +class MllamaVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class MllamaVisionAttention(nn.Module): + def __init__(self, config: MllamaVisionConfig): + super().__init__() + + self.embed_dim = config.hidden_size + self.num_heads = config.attention_heads + self.head_dim = config.hidden_size // config.attention_heads + + self.q_proj = nn.Linear( + self.embed_dim, self.num_heads * self.head_dim, bias=False + ) + self.k_proj = nn.Linear( + self.embed_dim, self.num_heads * self.head_dim, bias=False + ) + self.v_proj = nn.Linear( + self.embed_dim, self.num_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.embed_dim, bias=False + ) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = None, + ) -> torch.Tensor: + query = self.q_proj(hidden_state) + key = self.k_proj(hidden_state) + value = self.v_proj(hidden_state) + + batch_size, q_seq_len, _ = query.shape + _, kv_seq_len, _ = key.shape + + query = query.view( + batch_size, q_seq_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose( + 1, 2 + ) + value = value.view( + batch_size, kv_seq_len, self.num_heads, self.head_dim + ).transpose(1, 2) + + attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt( + self.head_dim + ) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query.dtype) + attn_output = torch.matmul(attn_weights, value) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_seq_len, -1) + + output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return output, attn_weights + + +class MllamaVisionSdpaAttention(MllamaVisionAttention): + # Adapted from MllamaVisionAttention + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = None, + ) -> torch.Tensor: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + if output_attentions: + logger.warning_once( + "MllamaModel is using MllamaVisionSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_state=hidden_state, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + query = self.q_proj(hidden_state) + key = self.k_proj(hidden_state) + value = self.v_proj(hidden_state) + + batch_size, q_seq_len, _ = query.shape + _, kv_seq_len, _ = key.shape + + query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim) + key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) + value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + attn_output = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_seq_len, -1) + + output = self.o_proj(attn_output) + + return output, None + + +MLLAMA_VISION_ATTENTION_CLASSES = { + "eager": MllamaVisionAttention, + "sdpa": MllamaVisionSdpaAttention, +} + + +class MllamaVisionEncoderLayer(nn.Module): + def __init__(self, config: MllamaVisionConfig, is_gated: bool = False): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.attention_heads + self.is_gated = is_gated + self.intermediate_size = config.intermediate_size + + self.self_attn = MLLAMA_VISION_ATTENTION_CLASSES[config._attn_implementation]( + config + ) + self.mlp = MllamaVisionMLP(config) + + self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps) + self.post_attention_layernorm = nn.LayerNorm( + self.hidden_size, eps=config.norm_eps + ) + + if is_gated: + self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4) + self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = None, + ): + # Self Attention + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + hidden_state, attn_weights = self.self_attn( + hidden_state, attention_mask=attention_mask + ) + if self.is_gated: + hidden_state = self.gate_attn.tanh() * hidden_state + hidden_state = residual + hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + if self.is_gated: + hidden_state = self.gate_ffn.tanh() * hidden_state + hidden_state = residual + hidden_state + + outputs = (hidden_state,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class MllamaVisionEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`MllamaEncoderLayer`]. + + Args: + config: MllamaConfig + """ + + def __init__(self, config: MllamaVisionConfig, num_layers=32, is_gated=False): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [MllamaVisionEncoderLayer(config, is_gated) for _ in range(num_layers)] + ) + self.gradient_checkpointing = False + self.config = config + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_state=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = layer_outputs[0] + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, encoder_states, all_attentions] + if v is not None + ) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText +class MllamaTextRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MllamaTextRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class MllamaTextCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Optional[MllamaTextConfig] = None, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.config = config + self.num_heads = self.config.num_attention_heads + self.num_key_value_heads = self.config.num_key_value_heads + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.head_dim = config.hidden_size // self.num_heads + self.layer_idx = layer_idx + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=False + ) + self.k_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) + + self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + query_states = self.q_norm(query_states) + + if cross_attention_states is not None: + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view( + bsz, -1, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, -1, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + key_states = self.k_norm(key_states) + if past_key_value is not None: + # if we have a new image + new tokens, we only computed key_states on that new image + # we still update the cross key states, past_image, new_image. And use it! + key_states, value_states = past_key_value.update( + key_states, + value_states, + self.layer_idx, + {"cache_position": cache_position}, + ) + elif cache_position[0] != 0: + key_states, value_states = ( + past_key_value.key_cache[self.layer_idx], + past_key_value.value_cache[self.layer_idx], + ) + else: + raise ValueError( + "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" + ) + + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MllamaTextCrossSdpaAttention(MllamaTextCrossAttention): + """ + Mllama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MllamaTextCrossAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from MllamaTextCrossAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MllamaModel is using MllamaTextCrossSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + cross_attention_states=cross_attention_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + query_states = self.q_norm(query_states) + + if cross_attention_states is not None: + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view( + bsz, -1, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, -1, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + if past_key_value is not None: + # if we have a new image + new tokens, we only computed key_states on that new image + # we still update the cross key states, past_image, new_image. And use it! + key_states, value_states = past_key_value.update( + key_states, + value_states, + self.layer_idx, + {"cache_position": cache_position}, + ) + elif cache_position[0] != 0: + key_states, value_states = ( + past_key_value.key_cache[self.layer_idx], + past_key_value.value_cache[self.layer_idx], + ) + else: + raise ValueError( + "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + key_states = self.k_norm(key_states) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if attention_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class MllamaTextSelfAttention(nn.Module): + def __init__(self, config: MllamaTextConfig, layer_idx: int): + super().__init__() + self.config = config + self.num_heads = config.num_attention_heads + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.rope_theta = config.rope_theta + self.layer_idx = layer_idx + + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=False + ) + self.k_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + output_attentions: bool = False, + use_cache: bool = False, + past_key_value=None, + cache_position=None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MllamaTextSelfSdpaAttention(MllamaTextSelfAttention): + # Adapted from MllamaTextSelfAttention + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + output_attentions: bool = False, + use_cache: bool = False, + past_key_value=None, + cache_position=None, + **kwargs, + ): + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MllamaModel is using MllamaTextSelfSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + return attn_output, None, past_key_value + + +MLLAMA_TEXT_CROSS_ATTENTION_CLASSES = { + "eager": MllamaTextCrossAttention, + "sdpa": MllamaTextCrossSdpaAttention, +} +MLLAMA_TEXT_ATTENTION_CLASSES = { + "eager": MllamaTextSelfAttention, + "sdpa": MllamaTextSelfSdpaAttention, +} + + +# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText +class MllamaTextMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + # Ignore copy + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer +class MllamaSelfAttentionDecoderLayer(nn.Module): + def __init__(self, config: MllamaTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MLLAMA_TEXT_ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx + ) + + self.mlp = MllamaTextMLP(config) + self.input_layernorm = MllamaTextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = MllamaTextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + self.layer_idx = layer_idx + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.45 + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class MllamaCrossAttentionDecoderLayer(torch.nn.Module): + """Cross-attention transformer block with tanh-gated attention and feedforward.""" + + def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None: + super().__init__() + self.layer_idx = layer_idx + self.cross_attn = MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[ + config._attn_implementation + ](config, layer_idx=layer_idx) + + self.input_layernorm = MllamaTextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1)) + + self.mlp = MllamaTextMLP(config) + self.post_attention_layernorm = MllamaTextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1)) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: torch.Tensor, + cross_attention_mask: torch.Tensor, + attention_mask: torch.Tensor, + full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, attn_weights, past_key_value = self.cross_attn( + hidden_states=hidden_states, + attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if full_text_row_masked_out_mask is not None: + hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states # type: ignore + hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + return outputs + + +class MllamaRotaryEmbedding(nn.Module): + def __init__(self, config: MllamaTextConfig, device=None): + super().__init__() + self.rope_type = config.rope_scaling["rope_type"] + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer( + "inv_freq", inv_freq, persistent=False + ) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if ( + seq_len < self.original_max_seq_len + and self.max_seq_len_cached > self.original_max_seq_len + ): # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = ( + device_type + if isinstance(device_type, str) and device_type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class MllamaPreTrainedModel(PreTrainedModel): + config_class = MllamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = [ + "MllamaVisionEncoderLayer", + "MllamaCrossAttentionDecoderLayer", + "MllamaSelfAttentionDecoderLayer", + ] + _supports_cache_class = True + _supports_static_cache = ( + False # static cache cannot have different shapes for each layer + ) + _supports_sdpa = True + _supports_quantized_cache = True + + def _init_weights(self, module): + std = self.config.get_text_config().initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.Parameter): + module.data.normal_(mean=0.0, std=std) + elif isinstance(module, MllamaVisionModel): + nn.init.normal_(module.class_embedding.data, std=std) + elif isinstance(module, MllamaPrecomputedPositionEmbedding): + nn.init.normal_(module.embedding.data, std=std) + elif isinstance(module, MllamaVisionEncoderLayer) and module.is_gated: + nn.init.normal_(module.gate_attn.data, std=std) + nn.init.normal_(module.gate_ffn.data, std=std) + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not using_static_cache + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended( + causal_mask, min_dtype + ) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange( + target_length, device=device + ) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = ( + causal_mask.clone() + ) # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + + attention_mask[:, None, None, :] + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[ + :, :, :, :mask_length + ].masked_fill(padding_mask, min_dtype) + + return causal_mask + + +MLLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MllamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +MLLAMA_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, max_num_images, max_num_tiles, channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`MllamaImageProcessor.__call__`] for details ([]`MllamaProcessor`] uses + [`MllamaImageProcessor`] for processing images). + aspect_ratio_mask (`torch.Tensor` of shape `(batch_size, max_num_images, max_num_tiles)`, *optional*): + Mask to avoid performing attention on padding tiles. Mask values selected in `[0, 1]`: + + - 1 for tiles that are **not masked**, + - 0 for tiles that are **masked**. + aspect_ratio_ids (`torch.Tensor` of shape `(batch_size, max_num_images)`, *optional*): + Aspect ratio ids used to select the appropriate precomputed tile embeddings based on the aspect ratio of each input image. + These ids correspond to indices in the model's list of supported aspect ratios, offset by 1. + + For example, if the model supports aspect ratios [[1, 1], [1, 2], [2, 1]]: + - An image with aspect ratio [1, 1] would have ID 1 + - An image with aspect ratio [1, 2] would have ID 2 + - An image with aspect ratio [2, 1] would have ID 3 + + The id 0 is reserved for padding (i.e., no image). + + If an image has aspect ratio [1, 2], that means it was split into 2 tiles horizontally, and its `aspect_ratio_id` would be 2. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +MLLAMA_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + cross_attention_mask (`torch.Tensor` of shape `(batch_size, seq_length, max_num_images, max_num_tiles)`, *optional*): + Cross-attention mask to control the interaction between text tokens and image tiles. + This 4D tensor defines which image tiles each text token should attend to. + + For each text token (in seq_length): + - 1 indicates the token **should attend** to the corresponding image tile + - 0 indicates the token **should not attend** to the corresponding image tile + cross_attention_states (`torch.FloatTensor`, *optional*): + Output of the vision model, used for cross-attention. This tensor contains the processed image features that + the language model will attend to. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +MLLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, max_num_images, max_num_tiles, channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`MllamaImageProcessor.__call__`] for details ([]`MllamaProcessor`] uses + [`MllamaImageProcessor`] for processing images). + aspect_ratio_mask (`torch.Tensor` of shape `(batch_size, max_num_images, max_num_tiles)`, *optional*): + Mask to avoid performing attention on padding tiles. Mask values selected in `[0, 1]`: + + - 1 for tiles that are **not masked**, + - 0 for tiles that are **masked**. + aspect_ratio_ids (`torch.Tensor` of shape `(batch_size, max_num_images)`, *optional*): + Aspect ratio ids used to select the appropriate precomputed tile embeddings based on the aspect ratio of each input image. + These ids correspond to indices in the model's list of supported aspect ratios, offset by 1. + + For example, if the model supports aspect ratios [[1, 1], [1, 2], [2, 1]]: + - An image with aspect ratio [1, 1] would have ID 1 + - An image with aspect ratio [1, 2] would have ID 2 + - An image with aspect ratio [2, 1] would have ID 3 + + The id 0 is reserved for padding (i.e., no image). + + If an image has aspect ratio [1, 2], that means it was split into 2 tiles horizontally, and its `aspect_ratio_id` would be 2. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + cross_attention_mask (`torch.Tensor` of shape `(batch_size, seq_length, max_num_images, max_num_tiles)`, *optional*): + Cross-attention mask to control the interaction between text tokens and image tiles. + This 4D tensor defines which image tiles each text token should attend to. + + For each text token (in seq_length): + - 1 indicates the token **should attend** to the corresponding image tile + - 0 indicates the token **should not attend** to the corresponding image tile + cross_attention_states (`torch.FloatTensor`, *optional*): + Output of the vision model, used for cross-attention. This tensor contains the processed image features that + the language model will attend to. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + """The Mllama Vision Model which consists of two vision encoders.""", + MLLAMA_START_DOCSTRING, +) +class MllamaVisionModel(MllamaPreTrainedModel): + config_class = MllamaVisionConfig + base_model_prefix = "vision_model" + + def __init__(self, config: MllamaVisionConfig): + super().__init__(config) + self.image_size = config.image_size + self.patch_size = config.patch_size + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.num_channels = config.num_channels + self.intermediate_layers_indices = config.intermediate_layers_indices + + self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 + self.scale = config.hidden_size**-0.5 + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + bias=False, + ) + + self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size)) + self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(config) + + self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( + config, is_gated=True + ) + self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( + config, is_gated=True + ) + + # layer norms + self.layernorm_pre = nn.LayerNorm(self.hidden_size) + self.layernorm_post = nn.LayerNorm(self.hidden_size) + + # encoders + self.transformer = MllamaVisionEncoder( + config, config.num_hidden_layers, is_gated=False + ) + self.global_transformer = MllamaVisionEncoder( + config, config.num_global_layers, is_gated=True + ) + + self.post_init() + + def get_input_embeddings(self): + """ + This function is used to fetch the first embedding layer to activate grads on inputs. + """ + return self.patch_embedding + + def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + @add_start_docstrings_to_model_forward(MLLAMA_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=BaseModelOutput, config_class="MllamaVisionConfig" + ) + def forward( + self, + pixel_values: torch.Tensor, + aspect_ratio_ids: torch.Tensor, + aspect_ratio_mask: torch.Tensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: + r""" + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, MllamaVisionModel + + >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision" + >>> model = MllamaVisionModel.from_pretrained(checkpoint) + >>> processor = AutoProcessor.from_pretrained(checkpoint) + + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor(images=image, return_tensors="pt") + + >>> output = model(**inputs) + + >>> print(output.last_hidden_state.shape) + torch.Size([1, 1, 4, 1025, 7680]) + ``` + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + batch_size, num_concurrent_media, num_tiles, num_channels, height, width = ( + pixel_values.shape + ) + + pixel_values = pixel_values.reshape( + batch_size * num_concurrent_media * num_tiles, num_channels, height, width + ) + aspect_ratio_ids = aspect_ratio_ids.reshape( + batch_size * num_concurrent_media, -1 + ) + + # Patch embedding + patch_embeds = self.patch_embedding(pixel_values.to(self.dtype).to(self.device)) + hidden_state = patch_embeds.flatten(2).transpose(1, 2) + + # Tile embeddings + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, -1, dim + ) + hidden_state = self.pre_tile_positional_embedding( + hidden_state, aspect_ratio_ids + ) + + # Add cls token + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media * num_tiles, num_patches, dim + ) + hidden_state = self.apply_class_embedding(hidden_state) + num_patches += 1 + + # Position embeddings + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches, dim + ) + hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids) + + hidden_state = self.layernorm_pre(hidden_state) + + # Compute the number of tokens to pad + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + # Compute padding tuple for pad function + padding = ( + 0, + 0, + 0, + num_padding_patches, + ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) + # Pad the tensor + hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + # Prepare attention mask + attention_mask = aspect_ratio_mask.reshape( + batch_size * num_concurrent_media, -1 + ) + attention_mask = _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + dtype=self.dtype, + ) + + # Apply encoder + hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim) + output = self.transformer( + hidden_state, + attention_mask=attention_mask, + output_hidden_states=True, + output_attentions=output_attentions, + ) + hidden_state = output[0] + + hidden_state = self.layernorm_post(hidden_state) + + # Apply global encoder + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim, + ) + hidden_state = self.post_tile_positional_embedding( + hidden_state, aspect_ratio_ids + ) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles * (num_patches + num_padding_patches), + dim, + ) + global_output = self.global_transformer( + hidden_state, + attention_mask=attention_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + hidden_state = global_output[0] + + # Remove padding form hidden state + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim, + ) + hidden_state = hidden_state[:, :, :slice_index] + hidden_state = hidden_state.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, dim + ) + + # Collect intermediate layer outputs from encoder output + all_intermediate_hidden_states = [ + output[1][i] for i in self.intermediate_layers_indices + ] + intermediate_hidden_states = torch.stack(all_intermediate_hidden_states, dim=-1) + + # Remove padding from intermediate hidden states + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + -1, + ) + intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1 + ) + + # Concatenate final hidden state and intermediate hidden states + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1) + + if output_hidden_states: + hidden_states = tuple(all_intermediate_hidden_states) + tuple( + global_output[1] + ) + else: + hidden_states = None + + if output_attentions: + # global transformer in contrast to `self.transformer` doesn't always return hidden states so we might go index out-of-range + global_attn = ( + tuple(global_output[2]) + if output_hidden_states + else tuple(global_output[1]) + ) + attentions = tuple(output[2]) + global_attn + else: + attentions = None + + if not return_dict: + return tuple( + v for v in [hidden_state, hidden_states, attentions] if v is not None + ) + + return BaseModelOutput( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + attentions=attentions, + ) + + +@add_start_docstrings( + """The Mllama Text Model which consists of transformer with self and cross attention layers.""", + MLLAMA_START_DOCSTRING, +) +class MllamaTextModel(MllamaPreTrainedModel): + config_class = MllamaTextConfig + base_model_prefix = "language_model.model" + + def __init__(self, config: MllamaTextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding( + config.vocab_size + 8, config.hidden_size, self.padding_idx + ) + self.cross_attention_layers = config.cross_attention_layers + + layers = [] + for layer_idx in range(config.num_hidden_layers): + if layer_idx in self.cross_attention_layers: + layers.append(MllamaCrossAttentionDecoderLayer(config, layer_idx)) + else: + layers.append(MllamaSelfAttentionDecoderLayer(config, layer_idx)) + + self.layers = nn.ModuleList(layers) + self.norm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = MllamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(MLLAMA_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=BaseModelOutputWithPast, config_class="MllamaTextConfig" + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.FloatTensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """ + + Returns: + + Example: + + ```python + >>> from transformers import AutoProcessor, MllamaTextModel + + >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision" + >>> model = MllamaTextModel.from_pretrained(checkpoint) + >>> processor = AutoProcessor.from_pretrained(checkpoint) + + >>> text = "<|image|>If I had to write a haiku for this one" + >>> inputs = processor(text=text, return_tensors="pt") + + >>> output = model(**inputs) + + >>> print(output.last_hidden_state.shape) + torch.Size([1, 13, 4096]) + ``` + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + ) + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # For text-only path we should skip cross attention layers. + # Let's check if the layer is cross attention layer and if we have cross attention states + # or cached cross attention states. + is_cross_attention_layer = idx in self.cross_attention_layers + is_cross_attention_cache_empty = past_key_values is None or ( + past_key_values is not None and past_key_values.get_seq_length(idx) == 0 + ) + + if ( + is_cross_attention_layer + and cross_attention_states is None + and is_cross_attention_cache_empty + ): + continue + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + cross_attention_states, + cross_attention_mask, + causal_mask, + full_text_row_masked_out_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + attention_mask=causal_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +@add_start_docstrings( + """The Mllama Text Model with a language modeling head on top.""", + MLLAMA_START_DOCSTRING, +) +class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin): + config_class = MllamaTextConfig + _supports_static_cache = True # only the LLM without cross attn can do compile + base_model_prefix = "language_model" + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config.get_text_config()) + self.text_config = config.get_text_config() + self.vocab_size = self.text_config.vocab_size + self.model = MllamaTextModel._from_config(self.text_config) + self.lm_head = nn.Linear( + self.text_config.hidden_size, self.vocab_size, bias=False + ) + + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig" + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.LongTensor] = None, + cross_attention_mask: Optional[torch.LongTensor] = None, + full_text_row_masked_out_mask: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **loss_kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MllamaForCausalLM + + >>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision") + >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision") + + >>> prompt = "If I had to write a haiku, it would be:" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6) + >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + >>> print(result) + If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful. + I love the idea of snowflakes gently falling, each one + ``` + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + cross_attention_states=cross_attention_states, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """The Mllama model which consists of a vision encoder and a language model.""", + MLLAMA_START_DOCSTRING, +) +class TracableMllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): + _supports_quantized_cache = ( + False # quant cache not supported in encoder-decoder setting + ) + + def __init__(self, config: MllamaConfig): + super().__init__(config) + self.vocab_size = config.text_config.vocab_size + self.hidden_size = config.text_config.hidden_size + self.max_num_tiles = config.vision_config.max_num_tiles + self.vision_output_dim = config.vision_config.vision_output_dim + self.pad_token_id = ( + self.config.pad_token_id if self.config.pad_token_id is not None else -1 + ) + + self.vision_model = MllamaVisionModel._from_config(config.vision_config) + self.language_model = MllamaForCausalLM._from_config(config.text_config) + self.multi_modal_projector = nn.Linear( + config.vision_config.vision_output_dim, + config.text_config.hidden_size, + bias=True, + ) + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.language_model.get_decoder() + + def tie_weights(self): + return self.language_model.tie_weights() + + @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class="MllamaConfig" + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, MllamaForConditionalGeneration + + >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision" + >>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint) + >>> processor = AutoProcessor.from_pretrained(checkpoint) + + >>> prompt = "<|image|>If I had to write a haiku for this one" + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(text=prompt, images=image, return_tensors="pt") + + >>> # Generate + >>> output = model.generate(**inputs, max_new_tokens=15) + + >>> prompt_len = inputs.input_ids.shape[-1] + >>> generated_ids = output[:, prompt_len:] + >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) + >>> print(generated_text) + [', it would be:.\\nA stop sign in Chinatown.\\n'] + ``` + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and cross_attention_states is not None: + raise ValueError( + "`pixel_values` and `cross_attention_states` cannot be provided simultaneously" + ) + + if pixel_values is not None: + if aspect_ratio_ids is None: + raise ValueError( + "`aspect_ratio_ids` must be provided if `pixel_values` is provided" + ) + # get vision tokens from vision model + vision_outputs = self.vision_model( + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + cross_attention_states = vision_outputs[0] + cross_attention_states = self.multi_modal_projector( + cross_attention_states + ).reshape(-1, cross_attention_states.shape[-2], self.hidden_size) + + if cross_attention_mask is not None: + cross_attention_mask, full_text_row_masked_out_mask = ( + _prepare_cross_attention_mask( + cross_attention_mask, + num_vision_tokens=self.vision_model.num_patches, + dtype=self.dtype, + ) + ) + else: + full_text_row_masked_out_mask = None + + if cross_attention_mask is not None and cache_position is not None: + cross_attention_mask = cross_attention_mask[:, :, cache_position] + full_text_row_masked_out_mask = full_text_row_masked_out_mask[ + :, :, cache_position + ] + + outputs = self.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + use_cache=use_cache, + inputs_embeds=inputs_embeds, + labels=labels, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + ) + + return outputs + + def prepare_inputs_for_generation( + self, + input_ids=None, + inputs_embeds=None, + attention_mask=None, + position_ids=None, + pixel_values=None, + aspect_ratio_ids=None, + aspect_ratio_mask=None, + cross_attention_mask=None, + past_key_values=None, + use_cache=False, + cache_position=None, + num_logits_to_keep=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # The clone here is for the same reason as for `position_ids`. + model_inputs = { + "input_ids": input_ids.clone(memory_format=torch.contiguous_format), + "inputs_embeds": None, + } + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "cross_attention_mask": cross_attention_mask, + } + ) + + # If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios + # to compute image hidden states, otherwise they are cached within each cross attn layer + if cache_position[0] == 0: + model_inputs["pixel_values"] = pixel_values + model_inputs["aspect_ratio_ids"] = aspect_ratio_ids + model_inputs["aspect_ratio_mask"] = aspect_ratio_mask + + return model_inputs + + def _update_model_kwargs_for_generation( + self, outputs, model_kwargs, is_encoder_decoder, **kwargs + ): + cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None) + model_kwargs = super()._update_model_kwargs_for_generation( + outputs=outputs, + model_kwargs=model_kwargs, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + # add cross-attn mask for new token + if cross_attention_mask_prev is not None: + model_kwargs["cross_attention_mask"] = torch.cat( + [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], + dim=1, + ) + return model_kwargs From 4e0b118a5072d26ff02b6074edc776a11b016db0 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 12 Dec 2024 17:20:19 -0500 Subject: [PATCH 186/285] write using comprehesion --- .../pipelines/sequential/cache.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/llmcompressor/pipelines/sequential/cache.py b/src/llmcompressor/pipelines/sequential/cache.py index 7d2defb74..511e458d5 100644 --- a/src/llmcompressor/pipelines/sequential/cache.py +++ b/src/llmcompressor/pipelines/sequential/cache.py @@ -5,8 +5,6 @@ import torch import tqdm -from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch - @dataclass class IntermediateValue: @@ -34,15 +32,20 @@ def from_dataloader( mask_padding: bool = True, offload_device: torch.device = "cpu", ): - batch_intermediates = [] - for batch in tqdm.tqdm(dataloader, desc="Preparing intermediates cache"): - if mask_padding and "attention_mask" in batch: - batch = apply_pad_mask_to_batch(batch) - batch = { - key: IntermediateValue(value=value, device=model_device) + batch_intermediates = [ + { + key: ( + IntermediateValue( + value=value.masked_fill_(batch["attention_mask"] == 0, 0), + device=model_device, + ) + if mask_padding and key == "input_ids" + else IntermediateValue(value=value, device=model_device) + ) for key, value in batch.items() } - batch_intermediates.append(batch) + for batch in tqdm.tqdm(dataloader, desc="Preparing intermediates cache") + ] return cls(batch_intermediates, offload_device) From 7dc4d2a79d793400f55b51ce34eefe2ace258185 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 12 Dec 2024 17:24:14 -0500 Subject: [PATCH 187/285] fix hessian requirements --- src/llmcompressor/transformers/compression/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/transformers/compression/helpers.py b/src/llmcompressor/transformers/compression/helpers.py index 8dd6a0cb6..7c839c5a7 100644 --- a/src/llmcompressor/transformers/compression/helpers.py +++ b/src/llmcompressor/transformers/compression/helpers.py @@ -236,7 +236,7 @@ def calculate_offload_device_map( reserved_memory = 0 if reserve_for_hessians: - reserved_memory = hessian_memory_requirements(dummy_model) * 2 + reserved_memory = hessian_memory_requirements(dummy_model) reserved_memory += quantization_memory_requirement(dummy_model) memory_limits = { From 377b2a45359b89cedcf719f17772c19674fc006a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 12 Dec 2024 17:36:30 -0500 Subject: [PATCH 188/285] implement offloading for tuple --- .../pipelines/sequential/cache.py | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/src/llmcompressor/pipelines/sequential/cache.py b/src/llmcompressor/pipelines/sequential/cache.py index 511e458d5..ce8374954 100644 --- a/src/llmcompressor/pipelines/sequential/cache.py +++ b/src/llmcompressor/pipelines/sequential/cache.py @@ -8,7 +8,7 @@ @dataclass class IntermediateValue: - value: Any + value: Union[torch.Tensor, "IntermediateValue", Any] device: Union[torch.device, None] @@ -77,12 +77,12 @@ def _onload_value(self, intermediate: IntermediateValue) -> Any: value = intermediate.value device = intermediate.device - if device is not None: - if isinstance(value, torch.Tensor): - return value.to(device=device) - else: - raise NotImplementedError("Intermediates") - + if isinstance(value, torch.Tensor): + return value.to(device=device) + + elif isinstance(value, tuple): + return tuple(self._onload_value(v) for v in value) + else: return value @@ -91,6 +91,20 @@ def _offload_value(self, value: Any) -> IntermediateValue: return IntermediateValue( value=value.to(device=self.offload_device), device=value.device ) + + if isinstance(value, tuple): + return IntermediateValue( + value=tuple(self._offload_value(v) for v in value), device=None + ) + + # if isinstance(value, MutableMapping): + # offloaded_value = + # for key in value: + # self._offload_value(value[key]) + + # return IntermediateValue( + # value=self._offload_value(v) for k, v in value.items()), device=None + # ) else: warnings.warn(f"Offloading not implemented for type {type(value)}.") From adb162750b3936323e0e5502bbd380a2470f21e8 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 12 Dec 2024 17:37:57 -0500 Subject: [PATCH 189/285] add save --- examples/multimodal_vision/mllama.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/multimodal_vision/mllama.py b/examples/multimodal_vision/mllama.py index 02fca32d2..84a106bb3 100644 --- a/examples/multimodal_vision/mllama.py +++ b/examples/multimodal_vision/mllama.py @@ -17,8 +17,8 @@ # Oneshot arguments DATASET_ID = "flickr30k" -DATASET_SPLIT = "test[:512]" -NUM_CALIBRATION_SAMPLES = 512 +DATASET_SPLIT = "test[:1]" +NUM_CALIBRATION_SAMPLES = 1 MAX_SEQUENCE_LENGTH = 2048 @@ -42,6 +42,7 @@ def data_collator(batch): targets="Linear", scheme="W8A8", ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"], + dampening_frac=100.0, ), ] @@ -64,6 +65,7 @@ def data_collator(batch): ) processor.save_pretrained(save_path) +model.save_pretrained(save_path) # Confirm generations of the quantized model look sane. print("========== SAMPLE GENERATION ==============") From ab3fc812fceb47634109afd05405e8e222e96fdf Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 12 Dec 2024 17:38:26 -0500 Subject: [PATCH 190/285] change num samples --- examples/multimodal_vision/mllama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/multimodal_vision/mllama.py b/examples/multimodal_vision/mllama.py index 84a106bb3..87dc21d5e 100644 --- a/examples/multimodal_vision/mllama.py +++ b/examples/multimodal_vision/mllama.py @@ -17,8 +17,8 @@ # Oneshot arguments DATASET_ID = "flickr30k" -DATASET_SPLIT = "test[:1]" -NUM_CALIBRATION_SAMPLES = 1 +DATASET_SPLIT = "test[:512]" +NUM_CALIBRATION_SAMPLES = 512 MAX_SEQUENCE_LENGTH = 2048 From 1bf683e0bc99aa6c031f29c167c82049053915b7 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 12 Dec 2024 17:52:55 -0500 Subject: [PATCH 191/285] implement intermediates offloading for dataclasses --- .../pipelines/sequential/cache.py | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/llmcompressor/pipelines/sequential/cache.py b/src/llmcompressor/pipelines/sequential/cache.py index ce8374954..3222aec0b 100644 --- a/src/llmcompressor/pipelines/sequential/cache.py +++ b/src/llmcompressor/pipelines/sequential/cache.py @@ -1,5 +1,5 @@ import warnings -from dataclasses import dataclass +from dataclasses import asdict, dataclass, is_dataclass from typing import Any, Dict, List, Union import torch @@ -79,10 +79,16 @@ def _onload_value(self, intermediate: IntermediateValue) -> Any: if isinstance(value, torch.Tensor): return value.to(device=device) - + + elif is_dataclass(value): + for key, v in asdict(value): + setattr(value, key, self._onload_value(v)) + + return value + elif isinstance(value, tuple): return tuple(self._onload_value(v) for v in value) - + else: return value @@ -91,20 +97,17 @@ def _offload_value(self, value: Any) -> IntermediateValue: return IntermediateValue( value=value.to(device=self.offload_device), device=value.device ) - + + elif is_dataclass(value): + for key, v in asdict(value): + setattr(value, key, self._offload_value(v)) + + return IntermediateValue(value=value, device=None) + if isinstance(value, tuple): return IntermediateValue( value=tuple(self._offload_value(v) for v in value), device=None ) - - # if isinstance(value, MutableMapping): - # offloaded_value = - # for key in value: - # self._offload_value(value[key]) - - # return IntermediateValue( - # value=self._offload_value(v) for k, v in value.items()), device=None - # ) else: warnings.warn(f"Offloading not implemented for type {type(value)}.") From b75fe15f18089d455709c081bb278ea1070ff4ee Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 12 Dec 2024 21:51:47 -0500 Subject: [PATCH 192/285] wrap ignore but do not treat as sequential target --- examples/multimodal_vision/mllama.py | 1 - examples/multimodal_vision/pixtral.py | 6 +-- .../modifiers/quantization/gptq/base.py | 10 ++-- .../pipelines/sequential/helpers.py | 52 +++++++++---------- .../pipelines/sequential/pipeline.py | 8 +-- 5 files changed, 33 insertions(+), 44 deletions(-) diff --git a/examples/multimodal_vision/mllama.py b/examples/multimodal_vision/mllama.py index 87dc21d5e..44df76e7c 100644 --- a/examples/multimodal_vision/mllama.py +++ b/examples/multimodal_vision/mllama.py @@ -42,7 +42,6 @@ def data_collator(batch): targets="Linear", scheme="W8A8", ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"], - dampening_frac=100.0, ), ] diff --git a/examples/multimodal_vision/pixtral.py b/examples/multimodal_vision/pixtral.py index 600a146c4..cf3ae5cb3 100644 --- a/examples/multimodal_vision/pixtral.py +++ b/examples/multimodal_vision/pixtral.py @@ -11,7 +11,7 @@ # Load model. model_id = "mgoin/pixtral-12b" model = TracableLlavaForConditionalGeneration.from_pretrained( - model_id, device_map="balanced", torch_dtype="auto" + model_id, device_map="auto", torch_dtype="auto" ) processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) @@ -57,8 +57,8 @@ def data_collator(batch): num_calibration_samples=NUM_CALIBRATION_SAMPLES, trust_remote_code_model=True, output_dir=save_path, - # data_collator=data_collator, - data_collator=DataCollator(), + data_collator=data_collator, + # data_collator=DataCollator(), ) model.save_pretrained(save_path) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 1cfef2feb..63d657a51 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -29,11 +29,7 @@ from llmcompressor.pipelines.sequential import run_pipeline as run_sequential from llmcompressor.transformers import tracing from llmcompressor.utils.metric_logging import CompressionLogger -from llmcompressor.utils.pytorch.module import ( - get_layers, - get_no_split_params, - qat_active, -) +from llmcompressor.utils.pytorch.module import get_no_split_params, qat_active __all__ = ["GPTQModifier"] @@ -213,8 +209,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool: # infer sequential targets if self.sequential_targets is None: self.sequential_targets = get_no_split_params(state.model) - elif isinstance(self.sequential_targets, str): - self.sequential_targets = get_layers(self.sequential_targets, self.model) + if isinstance(self.sequential_targets, str): + self.sequential_targets = [self.sequential_targets] # infer update size if self._update_size is None: diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index b0136a53b..29e454809 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -7,11 +7,14 @@ from compressed_tensors.quantization import find_name_or_class_matches from torch.fx import Graph, GraphModule, Node from torch.nn import Module +from transformers import PreTrainedModel from transformers.utils.fx import HFTracer from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.utils.helpers import calibration_forward_context +__all__ = ["trace_subgraphs", "Subgraph"] + @dataclass class Subgraph: @@ -25,34 +28,18 @@ def compile_forward(self): return code.globals.get("forward") -__all__ = ["infer_sequential_targets", "trace_subgraphs"] - - -def infer_sequential_targets( - model: Module, sequential_targets: List[str], ignore: List[str] -) -> Set[Module]: - """ - Future: infer from recipe - - List of modules which are guaranteed to be split into different partitions and - whose inner operations will not be traced - """ - targets_names = sequential_targets + ignore - - sequential_targets = set( - module - for name, module in model.named_modules() - if find_name_or_class_matches(name, module, targets_names) - ) - - return sequential_targets - - def trace_subgraphs( - model: Module, sample_input: Dict[str, Any], targets: Set[Module] + model: PreTrainedModel, + sample_input: Dict[str, Any], + sequential_targets: List[str], + ignore: List[str], ) -> List[Subgraph]: + # find modules + sequential_targets = match_modules(model, sequential_targets) + ignore = match_modules(model, ignore) + # initialize arguments - tracer = get_tracer(model, targets) + tracer = get_tracer(model, sequential_targets, ignore) concrete_args = populate_concrete_args(model, sample_input) # trace @@ -78,14 +65,16 @@ def trace_subgraphs( graph.device = model.device # perform subgraph partition - partitions = topological_partition(graph, targets) + partitions = topological_partition(graph, sequential_targets) subgraphs = partition_graph(model, partitions) trace_consumed_names(subgraphs) return subgraphs -def get_tracer(model: Module, sequential_targets: List[Module]) -> HFTracer: +def get_tracer( + model: Module, sequential_targets: Set[Module], ignore: Set[Module] +) -> HFTracer: offloaded_modules = set( module for module in model.modules() if has_offloaded_params(module) ) @@ -96,6 +85,7 @@ def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool: return ( module in sequential_targets or module in offloaded_modules + or module in ignore or super().is_leaf_module(module, module_qualified_name) ) @@ -261,3 +251,11 @@ def check_assumption(graph: Graph) -> bool: return False return True + + +def match_modules(model: Module, target_names: List[str]) -> Set[Module]: + return set( + module + for name, module in model.named_modules() + if find_name_or_class_matches(name, module, target_names) + ) diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 10a40cfc2..45d10e4f5 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -8,10 +8,7 @@ from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.pipelines.sequential.cache import IntermediatesCache -from llmcompressor.pipelines.sequential.helpers import ( - infer_sequential_targets, - trace_subgraphs, -) +from llmcompressor.pipelines.sequential.helpers import trace_subgraphs from llmcompressor.utils.helpers import calibration_forward_context __all__ = ["run_pipeline"] @@ -42,8 +39,7 @@ def run_pipeline( """ # trace subgraphs sample_input = next(iter(dataloader)) - targets = infer_sequential_targets(model, sequential_targets, ignore) - subgraphs = trace_subgraphs(model, sample_input, targets) + subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore) # FUTURE: apply recipe to model # initialize(recipe, model) From aa4a23deb9f1cde60514cd1a6f7551235985b103 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 12 Dec 2024 23:30:32 -0500 Subject: [PATCH 193/285] tracable pixtral/mistral --- examples/multimodal_vision/pixtral.py | 1 + .../pipelines/sequential/cache.py | 12 +- .../pipelines/sequential/helpers.py | 10 + .../transformers/tracing/__init__.py | 2 + .../transformers/tracing/llava.py | 30 +- .../transformers/tracing/mistral.py | 1531 +++++++++++++++++ 6 files changed, 1583 insertions(+), 3 deletions(-) create mode 100644 src/llmcompressor/transformers/tracing/mistral.py diff --git a/examples/multimodal_vision/pixtral.py b/examples/multimodal_vision/pixtral.py index cf3ae5cb3..bb3160052 100644 --- a/examples/multimodal_vision/pixtral.py +++ b/examples/multimodal_vision/pixtral.py @@ -40,6 +40,7 @@ def data_collator(batch): scheme="W8A8", ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"], sequential_targets=["MistralDecoderLayer"], + dampening_frac=100.0, ), ] diff --git a/src/llmcompressor/pipelines/sequential/cache.py b/src/llmcompressor/pipelines/sequential/cache.py index 3222aec0b..335955100 100644 --- a/src/llmcompressor/pipelines/sequential/cache.py +++ b/src/llmcompressor/pipelines/sequential/cache.py @@ -36,7 +36,7 @@ def from_dataloader( { key: ( IntermediateValue( - value=value.masked_fill_(batch["attention_mask"] == 0, 0), + value=cls._mask_padding(value, batch["attention_mask"]), device=model_device, ) if mask_padding and key == "input_ids" @@ -112,3 +112,13 @@ def _offload_value(self, value: Any) -> IntermediateValue: else: warnings.warn(f"Offloading not implemented for type {type(value)}.") return IntermediateValue(value=value, device=None) + + @staticmethod + def _mask_padding( + input_ids: torch.Tensor, attention_mask: torch.Tensor + ) -> torch.Tensor: + if attention_mask.dim() == 4: + # some attention masks, such as those from pixtral, are are 4d + attention_mask = attention_mask[0, 0, 0].unsqueeze(0) + + return input_ids.masked_fill_(torch.logical_not(attention_mask), 0) diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index 29e454809..e243ee4c4 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -6,8 +6,10 @@ from compressed_tensors import has_offloaded_params from compressed_tensors.quantization import find_name_or_class_matches from torch.fx import Graph, GraphModule, Node +from torch.fx.proxy import Argument from torch.nn import Module from transformers import PreTrainedModel +from transformers.configuration_utils import PretrainedConfig from transformers.utils.fx import HFTracer from llmcompressor.modifiers.utils.hooks import HooksMixin @@ -80,6 +82,14 @@ def get_tracer( ) class PiecewiseTracer(HFTracer): + def create_arg(self, a: Any) -> Argument: + if isinstance(a, PretrainedConfig): + kwargs = {k: self.create_arg(v) for k, v in a.to_dict().items()} + return self.create_node("call_function", a.__class__, (), kwargs) + + else: + return super().create_arg(a) + # Treat as leaf, skip tracing inside this module def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool: return ( diff --git a/src/llmcompressor/transformers/tracing/__init__.py b/src/llmcompressor/transformers/tracing/__init__.py index 8ca3d9777..bc2356a45 100644 --- a/src/llmcompressor/transformers/tracing/__init__.py +++ b/src/llmcompressor/transformers/tracing/__init__.py @@ -1,7 +1,9 @@ from .llava import TracableLlavaForConditionalGeneration from .mllama import TracableMllamaForConditionalGeneration +from .mistral import TracableMistralForCausalLM __all__ = [ "TracableLlavaForConditionalGeneration", "TracableMllamaForConditionalGeneration", + "TracableMistralForCausalLM", ] diff --git a/src/llmcompressor/transformers/tracing/llava.py b/src/llmcompressor/transformers/tracing/llava.py index 85a418e79..b5b5376bc 100644 --- a/src/llmcompressor/transformers/tracing/llava.py +++ b/src/llmcompressor/transformers/tracing/llava.py @@ -19,11 +19,37 @@ from typing import List, Optional, Tuple, Union import torch -from transformers import LlavaForConditionalGeneration -from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast, logger +from transformers import AutoModel, AutoModelForCausalLM, LlavaForConditionalGeneration +from transformers.models.llava.configuration_llava import LlavaConfig +from transformers.models.llava.modeling_llava import ( + LlavaCausalLMOutputWithPast, + LlavaMultiModalProjector, + logger, +) +from transformers.models.mistral.configuration_mistral import MistralConfig + +from .mistral import TracableMistralForCausalLM class TracableLlavaForConditionalGeneration(LlavaForConditionalGeneration): + def __init__(self, config: LlavaConfig): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config.vision_config) + + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.vocab_size = config.text_config.vocab_size + + # NOT TRACABLE: Must use TracableMistralForCausalLM which wraps untracable function + if isinstance(config.text_config, MistralConfig): + self.language_model = TracableMistralForCausalLM(config.text_config) + else: + self.language_model = AutoModelForCausalLM.from_config(config.text_config) + + self.pad_token_id = ( + self.config.pad_token_id if self.config.pad_token_id is not None else -1 + ) + self.post_init() + @wraps(LlavaForConditionalGeneration.forward) def forward( self, diff --git a/src/llmcompressor/transformers/tracing/mistral.py b/src/llmcompressor/transformers/tracing/mistral.py new file mode 100644 index 000000000..08f8d32d7 --- /dev/null +++ b/src/llmcompressor/transformers/tracing/mistral.py @@ -0,0 +1,1531 @@ +# flake8: noqa +# coding=utf-8 +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Mistral model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN +from transformers.cache_utils import ( + Cache, + DynamicCache, + SlidingWindowCache, + StaticCache, +) +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.mistral.configuration_mistral import MistralConfig +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1" +_CONFIG_FOR_DOC = "MistralConfig" + + +from torch.fx import wrap + + +@wrap +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: MistralConfig, + past_key_values: Cache, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`MistralConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + diagonal_attend_mask = torch.arange( + target_length, device=device + ) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if ( + not isinstance(past_key_values, SlidingWindowCache) + or sequence_length > target_length + ): + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = ( + causal_mask.clone() + ) # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[ + :, :, :, :mask_length + ].masked_fill(padding_mask, min_dtype) + return causal_mask + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral +class MistralRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MistralRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class MistralRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base + ** ( + torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) + / self.dim + ) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward + # TODO(joao): add me back asap :) + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = ( + device_type + if isinstance(device_type, str) and device_type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MistralMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj( + self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state) + ) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class MistralAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=False + ) + self.k_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) + + self.rotary_emb = MistralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MistralFlashAttention2(MistralAttention): + """ + Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ): + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += cache_position[0] + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self.config, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape( + bsz, q_len, self.num_heads * self.head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral +# TODO(joao): add me back asap :) +class MistralSdpaAttention(MistralAttention): + """ + Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MistralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from MistralAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +MISTRAL_ATTENTION_CLASSES = { + "eager": MistralAttention, + "flash_attention_2": MistralFlashAttention2, + "sdpa": MistralSdpaAttention, +} + + +# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Mistral, LLAMA->MISTRAL +# TODO(joao): add me back asap :) +class MistralDecoderLayer(nn.Module): + def __init__(self, config: MistralConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx + ) + + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = MistralRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +MISTRAL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MistralConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Mistral Model outputting raw hidden-states without any specific head on top.", + MISTRAL_START_DOCSTRING, +) +class MistralPreTrainedModel(PreTrainedModel): + config_class = MistralConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MistralDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MISTRAL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices indicating the position of the input sequence tokens in the sequence. Unlike `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Mistral Model outputting raw hidden-states without any specific head on top.", + MISTRAL_START_DOCSTRING, +) +class MistralModel(MistralPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`] + + Args: + config: MistralConfig + """ + + def __init__(self, config: MistralConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + MistralDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self._attn_implementation = config._attn_implementation + self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + use_cache, + output_attentions, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + use_cache: bool, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and use_cache: + is_padding_right = ( + attention_mask[:, -1].sum().item() != input_tensor.size()[0] + ) + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended( + causal_mask, min_dtype + ) + + return causal_mask + + +class TracableMistralForCausalLM(MistralPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = MistralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Ensure tensors are on the same device + shift_labels = shift_labels.to(shift_logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Mistral Model transformer with a sequence classification head on top (linear layer). + + [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + MISTRAL_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL +class MistralForSequenceClassification(MistralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MistralModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined." + ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = ( + torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + ) + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), sequence_lengths + ] + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + pooled_logits=pooled_logits, + config=self.config, + ) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Mistral Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + MISTRAL_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mistral, LLAMA->MISTRAL +class MistralForTokenClassification(MistralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MistralModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Mistral Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + MISTRAL_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->Mistral,LLAMA->MISTRAL,transformer->model +class MistralForQuestionAnswering(MistralPreTrainedModel): + base_model_prefix = "model" + + # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Mistral + def __init__(self, config): + super().__init__(config) + self.model = MistralModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function( + start_logits, end_logits, start_positions, end_positions, **kwargs + ) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) From aa532b503836160a83ffac26f3746c1729c9b236 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 12 Dec 2024 23:33:21 -0500 Subject: [PATCH 194/285] remove double saving --- examples/multimodal_vision/mllama.py | 3 --- examples/multimodal_vision/pixtral.py | 3 --- 2 files changed, 6 deletions(-) diff --git a/examples/multimodal_vision/mllama.py b/examples/multimodal_vision/mllama.py index 44df76e7c..2fac52e6d 100644 --- a/examples/multimodal_vision/mllama.py +++ b/examples/multimodal_vision/mllama.py @@ -63,9 +63,6 @@ def data_collator(batch): # data_collator=DataCollator(), ) -processor.save_pretrained(save_path) -model.save_pretrained(save_path) - # Confirm generations of the quantized model look sane. print("========== SAMPLE GENERATION ==============") input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") diff --git a/examples/multimodal_vision/pixtral.py b/examples/multimodal_vision/pixtral.py index bb3160052..a554f66b6 100644 --- a/examples/multimodal_vision/pixtral.py +++ b/examples/multimodal_vision/pixtral.py @@ -62,9 +62,6 @@ def data_collator(batch): # data_collator=DataCollator(), ) -model.save_pretrained(save_path) -processor.save_pretrained(save_path) - # Confirm generations of the quantized model look sane. print("========== SAMPLE GENERATION ==============") input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") From 19e4f971a81ecb8ad40dc2d7eb13b1e3ca61daee Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 12 Dec 2024 23:39:49 -0500 Subject: [PATCH 195/285] revert dampening frac --- examples/multimodal_vision/mllama.py | 3 ++- examples/multimodal_vision/pixtral.py | 4 ++-- src/llmcompressor/transformers/tracing/__init__.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/multimodal_vision/mllama.py b/examples/multimodal_vision/mllama.py index 2fac52e6d..3f61532e2 100644 --- a/examples/multimodal_vision/mllama.py +++ b/examples/multimodal_vision/mllama.py @@ -4,7 +4,8 @@ from transformers import AutoProcessor from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.pytorch.data_collator import DataCollator + +# from llmcompressor.pytorch.data_collator import DataCollator from llmcompressor.transformers import oneshot from llmcompressor.transformers.tracing import TracableMllamaForConditionalGeneration diff --git a/examples/multimodal_vision/pixtral.py b/examples/multimodal_vision/pixtral.py index a554f66b6..f311582a6 100644 --- a/examples/multimodal_vision/pixtral.py +++ b/examples/multimodal_vision/pixtral.py @@ -4,7 +4,8 @@ from transformers import AutoProcessor from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.pytorch.data_collator import DataCollator + +# from llmcompressor.pytorch.data_collator import DataCollator from llmcompressor.transformers import oneshot from llmcompressor.transformers.tracing import TracableLlavaForConditionalGeneration @@ -40,7 +41,6 @@ def data_collator(batch): scheme="W8A8", ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"], sequential_targets=["MistralDecoderLayer"], - dampening_frac=100.0, ), ] diff --git a/src/llmcompressor/transformers/tracing/__init__.py b/src/llmcompressor/transformers/tracing/__init__.py index bc2356a45..88be0fe88 100644 --- a/src/llmcompressor/transformers/tracing/__init__.py +++ b/src/llmcompressor/transformers/tracing/__init__.py @@ -1,6 +1,6 @@ from .llava import TracableLlavaForConditionalGeneration -from .mllama import TracableMllamaForConditionalGeneration from .mistral import TracableMistralForCausalLM +from .mllama import TracableMllamaForConditionalGeneration __all__ = [ "TracableLlavaForConditionalGeneration", From f95b77fc0738dca87111bac3a1d98a876a420cd2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 13 Dec 2024 10:10:52 -0500 Subject: [PATCH 196/285] do not cache model outputs to save memory --- src/llmcompressor/pipelines/sequential/pipeline.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 45d10e4f5..74be738b6 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -48,7 +48,6 @@ def run_pipeline( # prepare intermediates cache model_device = get_execution_device(model) intermediates = IntermediatesCache.from_dataloader(dataloader, model_device) - model_outputs = [dict() for _ in range(len(dataloader))] num_subgraphs = len(subgraphs) for subgraph_index, subgraph in enumerate(subgraphs): @@ -64,7 +63,6 @@ def run_pipeline( for batch_index in tqdm.tqdm(range(len(dataloader)), desc=calib_desc): inputs = intermediates.fetch(batch_index, subgraph.input_names) forward_function(model, **inputs) - del inputs # if using propagate_error, then this pass does not trigger modifier hooks # and is only used for capturing intermediates @@ -74,10 +72,7 @@ def run_pipeline( for batch_index in tqdm.tqdm(range(len(dataloader)), desc=desc): inputs = intermediates.fetch(batch_index, subgraph.input_names) output = forward_function(model, **inputs) - del inputs if subgraph_index < len(subgraphs) - 1: intermediates.update(batch_index, output) intermediates.delete(batch_index, subgraph.consumed_names) - else: - model_outputs[batch_index] = output From 2d890dbc491d9f589aed29e935f4eefa1305b44b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 13 Dec 2024 14:57:58 -0500 Subject: [PATCH 197/285] fix dataclass case, add tests --- .../pipelines/sequential/cache.py | 18 +- .../pipelines/sequential/test_cache.py | 169 ++++++++++++++++++ 2 files changed, 182 insertions(+), 5 deletions(-) create mode 100644 tests/llmcompressor/pipelines/sequential/test_cache.py diff --git a/src/llmcompressor/pipelines/sequential/cache.py b/src/llmcompressor/pipelines/sequential/cache.py index 335955100..baa976528 100644 --- a/src/llmcompressor/pipelines/sequential/cache.py +++ b/src/llmcompressor/pipelines/sequential/cache.py @@ -1,5 +1,5 @@ import warnings -from dataclasses import asdict, dataclass, is_dataclass +from dataclasses import dataclass, fields, is_dataclass from typing import Any, Dict, List, Union import torch @@ -81,14 +81,18 @@ def _onload_value(self, intermediate: IntermediateValue) -> Any: return value.to(device=device) elif is_dataclass(value): - for key, v in asdict(value): - setattr(value, key, self._onload_value(v)) + for field in fields(value): # `asdict` is recursive, not applicable here + v = getattr(value, field.name) + setattr(value, field.name, self._onload_value(v)) return value elif isinstance(value, tuple): return tuple(self._onload_value(v) for v in value) + elif isinstance(value, (int, str, float, bool)) or value is None: + return value + else: return value @@ -99,8 +103,9 @@ def _offload_value(self, value: Any) -> IntermediateValue: ) elif is_dataclass(value): - for key, v in asdict(value): - setattr(value, key, self._offload_value(v)) + for field in fields(value): # `asdict` is recursive, not applicable here + v = getattr(value, field.name) + setattr(value, field.name, self._offload_value(v)) return IntermediateValue(value=value, device=None) @@ -109,6 +114,9 @@ def _offload_value(self, value: Any) -> IntermediateValue: value=tuple(self._offload_value(v) for v in value), device=None ) + if isinstance(value, (int, str, float, bool)) or value is None: + return IntermediateValue(value=value, device=None) + else: warnings.warn(f"Offloading not implemented for type {type(value)}.") return IntermediateValue(value=value, device=None) diff --git a/tests/llmcompressor/pipelines/sequential/test_cache.py b/tests/llmcompressor/pipelines/sequential/test_cache.py new file mode 100644 index 000000000..253645294 --- /dev/null +++ b/tests/llmcompressor/pipelines/sequential/test_cache.py @@ -0,0 +1,169 @@ +from dataclasses import dataclass + +import pytest +import torch +from torch.utils.data import DataLoader, StackDataset + +from llmcompressor.pipelines.sequential.cache import ( + IntermediatesCache, + IntermediateValue, +) + + +@pytest.fixture +def sample_dataloader(): + # Create sample input tensors + input_ids = torch.tensor([[1, 2, 3, 0], [4, 5, 6, 0]], dtype=torch.long) + attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 0]], dtype=torch.bool) + + # Create dataset and dataloader + dataset = StackDataset(input_ids=input_ids, attention_mask=attention_mask) + return DataLoader(dataset, batch_size=2) + + +@pytest.fixture +def sample_cache(sample_dataloader): + return IntermediatesCache.from_dataloader( + dataloader=sample_dataloader, + model_device=torch.device("cpu"), + mask_padding=True, + offload_device=torch.device("cpu"), + ) + + +def test_initialization(sample_dataloader): + cache = IntermediatesCache.from_dataloader( + dataloader=sample_dataloader, + model_device=torch.device("cpu"), + mask_padding=True, + ) + + assert isinstance(cache, IntermediatesCache) + assert len(cache.batch_intermediates) > 0 + assert isinstance(cache.batch_intermediates[0], dict) + + +def test_fetch_inputs(sample_cache): + fetched = sample_cache.fetch(0, ["input_ids", "attention_mask"]) + + assert isinstance(fetched, dict) + assert "input_ids" in fetched + assert "attention_mask" in fetched + assert isinstance(fetched["input_ids"], torch.Tensor) + assert isinstance(fetched["attention_mask"], torch.Tensor) + + +def test_update_intermediates(sample_cache): + new_outputs = { + "hidden_states": torch.randn(2, 4, 768), + "logits": torch.randn(2, 4, 1000), + } + + sample_cache.update(0, new_outputs) + + # Verify the updates were stored + assert "hidden_states" in sample_cache.batch_intermediates[0] + assert "logits" in sample_cache.batch_intermediates[0] + + +def test_delete_intermediates(sample_cache): + # First add some intermediates + new_outputs = { + "hidden_states": torch.randn(2, 4, 768), + "logits": torch.randn(2, 4, 1000), + } + sample_cache.update(0, new_outputs) + + # Then delete them + sample_cache.delete(0, ["hidden_states"]) + + assert "hidden_states" not in sample_cache.batch_intermediates[0] + assert "logits" in sample_cache.batch_intermediates[0] + + +def test_mask_padding(): + input_ids = torch.tensor([[1, 2, 3, 0], [4, 5, 6, 0]]) + attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 0]]) + + masked = IntermediatesCache._mask_padding(input_ids, attention_mask) + + # Check if padding tokens are properly masked + expected = torch.tensor([[1, 2, 3, 0], [4, 5, 6, 0]]) + assert torch.equal(masked, expected) + + +def test_offload_and_onload_tensor(): + cache = IntermediatesCache([], torch.device("cpu")) + + # Test tensor offloading + original_tensor = torch.randn(2, 3).to("cpu") + offloaded = cache._offload_value(original_tensor) + + assert isinstance(offloaded, IntermediateValue) + assert isinstance(offloaded.value, torch.Tensor) + assert offloaded.device == original_tensor.device + + # Test tensor onloading + onloaded = cache._onload_value(offloaded) + assert torch.equal(onloaded, original_tensor) + + +@dataclass +class SampleDataclass: + a: torch.Tensor + b: int + + +def test_offload_and_onload_dataclass(): + cache = IntermediatesCache([], torch.device("cpu")) + + # Create a sample dataclass instance + sample_data = SampleDataclass(a=torch.randn(2, 3), b=42) + + # Test dataclass offloading + offloaded = cache._offload_value(sample_data) + assert isinstance(offloaded, IntermediateValue) + assert isinstance(offloaded.value, SampleDataclass) + assert isinstance(offloaded.value.a, IntermediateValue) + assert isinstance(offloaded.value.b, IntermediateValue) + + # Test dataclass onloading + onloaded = cache._onload_value(offloaded) + assert onloaded == sample_data + + +def test_4d_attention_mask(): + input_ids = torch.tensor([[1, 2, 3, 0]]) + attention_mask = torch.ones(1, 1, 1, 4) # 4D attention mask + + masked = IntermediatesCache._mask_padding(input_ids, attention_mask) + + # Check if the function handles 4D attention mask properly + expected = torch.tensor([[1, 2, 3, 0]]) + assert torch.equal(masked, expected) + + +def test_device_handling(sample_dataloader): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + cuda_device = torch.device("cuda") + cpu_device = torch.device("cpu") + + # Create a cache with GPU as model device and CPU as offload device + cache = IntermediatesCache.from_dataloader( + dataloader=sample_dataloader, + model_device=cuda_device, + offload_device=cpu_device, + ) + + # Add some GPU tensors + new_outputs = {"hidden_states": torch.randn(2, 3).to(cuda_device)} + cache.update(0, new_outputs) + + # Verify tensors are offloaded to CPU + assert cache.batch_intermediates[0]["hidden_states"].value.device.type == "cpu" + + # Verify tensors are loaded back to GPU when fetched + fetched = cache.fetch(0, ["hidden_states"]) + assert fetched["hidden_states"].device.type == "cuda" From 4a22032a29403b0777114d92217d2e0191156677 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 13 Dec 2024 20:15:07 +0000 Subject: [PATCH 198/285] Remove docstring --- src/llmcompressor/modifiers/quantization/gptq/base.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 63d657a51..c12dba096 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -38,16 +38,6 @@ class GPTQModifier(Modifier, HooksMixin): """ Modifier for applying the one-shot OBCQ algorithm to a model - Lifecycle: - - on_initialize - - initialize_compression() - - compressible_layers() - - LayerCompressor.pre_compress() - - apply_compression() - - run_calibration_forward() - - LayerCompressor.compress() - - LayerCompressor.post_compress() - - LayerCompressor.revert_layer_wrappers() | Sample yaml: | test_stage: | obcq_modifiers: From a71352ad22904cc2ccab96890039f90051959451 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 14 Dec 2024 15:19:29 -0500 Subject: [PATCH 199/285] move IntermediatesCache location --- .../pipelines/{sequential => }/cache.py | 29 +++++++++++-------- .../pipelines/sequential/pipeline.py | 4 +-- .../pipelines/{sequential => }/test_cache.py | 5 +--- 3 files changed, 20 insertions(+), 18 deletions(-) rename src/llmcompressor/pipelines/{sequential => }/cache.py (84%) rename tests/llmcompressor/pipelines/{sequential => }/test_cache.py (97%) diff --git a/src/llmcompressor/pipelines/sequential/cache.py b/src/llmcompressor/pipelines/cache.py similarity index 84% rename from src/llmcompressor/pipelines/sequential/cache.py rename to src/llmcompressor/pipelines/cache.py index baa976528..b4e8a440c 100644 --- a/src/llmcompressor/pipelines/sequential/cache.py +++ b/src/llmcompressor/pipelines/cache.py @@ -1,6 +1,6 @@ import warnings from dataclasses import dataclass, fields, is_dataclass -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union import torch import tqdm @@ -24,6 +24,11 @@ def __init__( self.batch_intermediates = batch_intermediates self.offload_device = offload_device + @classmethod + def empty(cls, num_batches: int, offload_device: torch.device): + batch_intermediates = [{} for _ in range(num_batches)] + return cls(batch_intermediates, offload_device) + @classmethod def from_dataloader( cls, @@ -49,27 +54,27 @@ def from_dataloader( return cls(batch_intermediates, offload_device) - def fetch(self, batch_index: int, input_names: List[str]) -> Dict[str, Any]: + def fetch( + self, batch_index: int, input_names: Optional[List[str]] = None + ) -> Dict[str, Any]: intermediates = self.batch_intermediates[batch_index] return { key: self._onload_value(subgraph_input) for key, subgraph_input in intermediates.items() - if key in input_names - } - - def update(self, batch_index: int, outputs: Dict[str, Any]): - # assume that all model intermediates are tensors - assert (isinstance(value, torch.Tensor) for value in outputs.values()) - - intermediates = { - key: self._offload_value(value) for key, value in outputs.items() + if input_names is None or key in input_names } + def update(self, batch_index: int, values: Dict[str, Any]): + intermediates = {k: self._offload_value(v) for k, v in values.items()} self.batch_intermediates[batch_index].update(intermediates) - def delete(self, batch_index: int, consumed_names: List[str]): + def delete(self, batch_index: int, consumed_names: Optional[List[str]] = None): intermediates = self.batch_intermediates[batch_index] + + if consumed_names is None: + consumed_names = list(intermediates.keys()) + for name in consumed_names: del intermediates[name] diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 74be738b6..6dbd4d58d 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -7,7 +7,7 @@ from compressed_tensors.utils import get_execution_device from llmcompressor.modifiers.utils.hooks import HooksMixin -from llmcompressor.pipelines.sequential.cache import IntermediatesCache +from llmcompressor.pipelines.cache import IntermediatesCache from llmcompressor.pipelines.sequential.helpers import trace_subgraphs from llmcompressor.utils.helpers import calibration_forward_context @@ -73,6 +73,6 @@ def run_pipeline( inputs = intermediates.fetch(batch_index, subgraph.input_names) output = forward_function(model, **inputs) - if subgraph_index < len(subgraphs) - 1: + if subgraph_index < num_subgraphs - 1: intermediates.update(batch_index, output) intermediates.delete(batch_index, subgraph.consumed_names) diff --git a/tests/llmcompressor/pipelines/sequential/test_cache.py b/tests/llmcompressor/pipelines/test_cache.py similarity index 97% rename from tests/llmcompressor/pipelines/sequential/test_cache.py rename to tests/llmcompressor/pipelines/test_cache.py index 253645294..71a72eb25 100644 --- a/tests/llmcompressor/pipelines/sequential/test_cache.py +++ b/tests/llmcompressor/pipelines/test_cache.py @@ -4,10 +4,7 @@ import torch from torch.utils.data import DataLoader, StackDataset -from llmcompressor.pipelines.sequential.cache import ( - IntermediatesCache, - IntermediateValue, -) +from llmcompressor.pipelines.cache import IntermediatesCache, IntermediateValue @pytest.fixture From 2d249a2fced0cc8f074e5b6599e1d65a285d72de Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 14 Dec 2024 15:38:10 -0500 Subject: [PATCH 200/285] add fake_sequential --- .../pipelines/fake_sequential/__init__.py | 2 + .../pipelines/fake_sequential/helpers.py | 99 +++++++++++++++++++ .../pipelines/fake_sequential/pipeline.py | 61 ++++++++++++ 3 files changed, 162 insertions(+) create mode 100644 src/llmcompressor/pipelines/fake_sequential/__init__.py create mode 100644 src/llmcompressor/pipelines/fake_sequential/helpers.py create mode 100644 src/llmcompressor/pipelines/fake_sequential/pipeline.py diff --git a/src/llmcompressor/pipelines/fake_sequential/__init__.py b/src/llmcompressor/pipelines/fake_sequential/__init__.py new file mode 100644 index 000000000..fc60475ca --- /dev/null +++ b/src/llmcompressor/pipelines/fake_sequential/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa +from .pipeline import run_pipeline diff --git a/src/llmcompressor/pipelines/fake_sequential/helpers.py b/src/llmcompressor/pipelines/fake_sequential/helpers.py new file mode 100644 index 000000000..483a5e336 --- /dev/null +++ b/src/llmcompressor/pipelines/fake_sequential/helpers.py @@ -0,0 +1,99 @@ +import contextlib +import inspect +from dataclasses import dataclass +from typing import Any, Dict, List, Set, Tuple + +import torch +import tqdm +from compressed_tensors.quantization import find_name_or_class_matches +from compressed_tensors.utils import get_execution_device +from torch.nn import Module +from torch.utils.data.dataloader import DataLoader + +from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch +from llmcompressor.pipelines.cache import IntermediatesCache +from llmcompressor.pytorch.utils.helpers import tensors_to_device +from llmcompressor.utils.helpers import calibration_forward_context + +__all__ = ["match_modules", "compute_first_layer_intermediates"] + + +def match_modules(model: Module, target_names: List[str]) -> List[Module]: + names_layers = [ + (name, module) + for name, module in model.named_modules() + if find_name_or_class_matches(name, module, target_names) + ] + + names_layers = sorted(names_layers, key=lambda name_layer: name_layer[0]) + return [layer for _name, layer in names_layers] + + +def compute_first_layer_intermediates( + model: Module, + layers: List[Module], + dataloader: DataLoader, + mask_padding: bool = True, +) -> IntermediatesCache: + model_device = get_execution_device(model) + intermediates = IntermediatesCache.empty(len(dataloader), torch.device("cpu")) + first_layer = layers[0] + signature = inspect.signature(first_layer.forward) + + with calibration_forward_context(model), early_stop_hook(first_layer): + desc = "Preparing intermediates cache" + for batch_index, batch in enumerate(tqdm.tqdm(dataloader, desc=desc)): + batch = apply_pad_mask_to_batch(batch) if mask_padding else batch + batch = tensors_to_device(batch, model_device) + + try: + model(**batch) + except EarlyStopException as exception: + layer_args = args_to_kwargs(exception._args, signature) + assert not set(layer_args.keys()) & set(exception._kwargs.keys()) + layer_args.update(exception._kwargs) + + intermediates.update(batch_index, layer_args) + else: + raise ValueError( + "Attempted to capture first layer intermediates, but " + "EarlyStopException was not raised" + ) + + return intermediates + + +def to_next_layer_kwargs(args: Tuple[Any, ...], next_layer: Module) -> Dict[str, Any]: + signature = inspect.signature(next_layer.forward) + return args_to_kwargs(args, signature) + + +def args_to_kwargs( + args: Tuple[Any, ...], signature: inspect.Signature +) -> Dict[str, Any]: + return {name: arg for name, arg in zip(signature.parameters.keys(), args)} + + +@contextlib.contextmanager +def early_stop_hook(module: Module): + def trigger_early_stop_fn(module, args, kwargs): + raise EarlyStopException(_args=args, _kwargs=kwargs) + + handle = module.register_forward_pre_hook(trigger_early_stop_fn, with_kwargs=True) + + yield + + handle.remove() + + +@dataclass +class EarlyStopException(Exception): + """ + Note: this is exception different from the exception defined in + llmcompressor.modifiers.utils.pytorch_helpers, and will eventually replace + + Attribute names `args` and `kwargs` are reserved for `dataclass` + """ + + _args: Tuple[Any, ...] + _kwargs: Dict[str, Any] diff --git a/src/llmcompressor/pipelines/fake_sequential/pipeline.py b/src/llmcompressor/pipelines/fake_sequential/pipeline.py new file mode 100644 index 000000000..24ff7135e --- /dev/null +++ b/src/llmcompressor/pipelines/fake_sequential/pipeline.py @@ -0,0 +1,61 @@ +from contextlib import nullcontext +from typing import List + +import torch +import torch.utils.data.dataloader +import tqdm +from compressed_tensors.utils import get_execution_device + +from llmcompressor.modifiers.utils.hooks import HooksMixin +from llmcompressor.pipelines.cache import IntermediatesCache +from llmcompressor.pipelines.fake_sequential.helpers import ( + compute_first_layer_intermediates, + match_modules, + to_next_layer_kwargs, +) +from llmcompressor.utils.helpers import calibration_forward_context + +__all__ = ["run_pipeline"] + + +def run_pipeline( + model: torch.nn.Module, + sequential_targets: List[str], # FUTURE: replace with recipe inference + dataloader: torch.utils.data.DataLoader, + propagate_error: bool, +): + """ """ + # find layers + layers = match_modules(model, sequential_targets) + + # FUTURE: apply recipe to model + # initialize(recipe, model) + + with calibration_forward_context(model): + intermediates = compute_first_layer_intermediates(model, layers, dataloader) + + num_layers = len(layers) + for layer_index, layer in enumerate(layers): + # prepare tqdm description texts + calib_desc = f"({layer_index + 1}/{num_layers}): Calibrating" + prop_desc = f"({layer_index + 1}/{num_layers}): Propagate" + + if propagate_error: + # do an preliminary pass to trigger modifier hooks + for batch_index in tqdm.tqdm(range(len(dataloader)), desc=calib_desc): + inputs = intermediates.fetch(batch_index) + layer(**inputs) + + # if using propagate_error, then this pass does not trigger modifier hooks + # and is only used for capturing intermediates + # otherwise, this pass triggers modifier hooks and captures intermediates + with HooksMixin.disable_hooks() if propagate_error else nullcontext(): + desc = prop_desc if propagate_error else calib_desc + for batch_index in tqdm.tqdm(range(len(dataloader)), desc=desc): + inputs = intermediates.fetch(batch_index) + output = layer(**inputs) + output = to_next_layer_kwargs(output, layers[layer_index + 1]) + + if layer_index < num_layers - 1: + intermediates.delete(batch_index) + intermediates.update(batch_index, output) From 995cb2dd68607c70ab36339552f502fedc4ef577 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 14 Dec 2024 16:05:26 -0500 Subject: [PATCH 201/285] rename fake_sequential to layer_sequential --- .../__init__.py | 0 .../helpers.py | 13 +++++++------ .../pipeline.py | 8 +++----- 3 files changed, 10 insertions(+), 11 deletions(-) rename src/llmcompressor/pipelines/{fake_sequential => layer_sequential}/__init__.py (100%) rename src/llmcompressor/pipelines/{fake_sequential => layer_sequential}/helpers.py (93%) rename src/llmcompressor/pipelines/{fake_sequential => layer_sequential}/pipeline.py (88%) diff --git a/src/llmcompressor/pipelines/fake_sequential/__init__.py b/src/llmcompressor/pipelines/layer_sequential/__init__.py similarity index 100% rename from src/llmcompressor/pipelines/fake_sequential/__init__.py rename to src/llmcompressor/pipelines/layer_sequential/__init__.py diff --git a/src/llmcompressor/pipelines/fake_sequential/helpers.py b/src/llmcompressor/pipelines/layer_sequential/helpers.py similarity index 93% rename from src/llmcompressor/pipelines/fake_sequential/helpers.py rename to src/llmcompressor/pipelines/layer_sequential/helpers.py index 483a5e336..8004539f7 100644 --- a/src/llmcompressor/pipelines/fake_sequential/helpers.py +++ b/src/llmcompressor/pipelines/layer_sequential/helpers.py @@ -1,7 +1,7 @@ import contextlib import inspect from dataclasses import dataclass -from typing import Any, Dict, List, Set, Tuple +from typing import Any, Dict, List, Tuple import torch import tqdm @@ -15,7 +15,7 @@ from llmcompressor.pytorch.utils.helpers import tensors_to_device from llmcompressor.utils.helpers import calibration_forward_context -__all__ = ["match_modules", "compute_first_layer_intermediates"] +__all__ = ["match_modules", "capture_first_layer_intermediates", "to_next_layer_kwargs"] def match_modules(model: Module, target_names: List[str]) -> List[Module]: @@ -29,7 +29,7 @@ def match_modules(model: Module, target_names: List[str]) -> List[Module]: return [layer for _name, layer in names_layers] -def compute_first_layer_intermediates( +def capture_first_layer_intermediates( model: Module, layers: List[Module], dataloader: DataLoader, @@ -81,9 +81,10 @@ def trigger_early_stop_fn(module, args, kwargs): handle = module.register_forward_pre_hook(trigger_early_stop_fn, with_kwargs=True) - yield - - handle.remove() + try: + yield + finally: + handle.remove() @dataclass diff --git a/src/llmcompressor/pipelines/fake_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py similarity index 88% rename from src/llmcompressor/pipelines/fake_sequential/pipeline.py rename to src/llmcompressor/pipelines/layer_sequential/pipeline.py index 24ff7135e..aa5c57e2a 100644 --- a/src/llmcompressor/pipelines/fake_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -4,12 +4,10 @@ import torch import torch.utils.data.dataloader import tqdm -from compressed_tensors.utils import get_execution_device from llmcompressor.modifiers.utils.hooks import HooksMixin -from llmcompressor.pipelines.cache import IntermediatesCache -from llmcompressor.pipelines.fake_sequential.helpers import ( - compute_first_layer_intermediates, +from llmcompressor.pipelines.layer_sequential.helpers import ( + capture_first_layer_intermediates, match_modules, to_next_layer_kwargs, ) @@ -32,7 +30,7 @@ def run_pipeline( # initialize(recipe, model) with calibration_forward_context(model): - intermediates = compute_first_layer_intermediates(model, layers, dataloader) + intermediates = capture_first_layer_intermediates(model, layers, dataloader) num_layers = len(layers) for layer_index, layer in enumerate(layers): From e4bca342d46bbf278050b2545baa7b7a550b48b1 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 14 Dec 2024 16:05:36 -0500 Subject: [PATCH 202/285] pipeline inference --- .../modifiers/quantization/gptq/base.py | 49 ++++++++++++------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index c12dba096..cb21d5e90 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -1,5 +1,4 @@ import contextlib -import traceback import warnings from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union @@ -26,8 +25,10 @@ from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.pipelines.basic import run_pipeline as run_basic +from llmcompressor.pipelines.layer_sequential import ( + run_pipeline as run_layer_sequential, +) from llmcompressor.pipelines.sequential import run_pipeline as run_sequential -from llmcompressor.transformers import tracing from llmcompressor.utils.metric_logging import CompressionLogger from llmcompressor.utils.pytorch.module import get_no_split_params, qat_active @@ -207,33 +208,43 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self._update_size = len(state.data.calib) # infer pipeline - if ( - state.model.__class__.__name__ in tracing.__all__ - or "pixel_values" not in state.data.calib.dataset.column_names - ): + try: + run_sequential( + state.model, + self.sequential_targets, + self.ignore, + state.data.calib, + propagate_error=True, + ) + return True + + # failure to trace + except torch.fx.proxy.TraceError: + model_name = state.model.__class__.__name__ + column_names = state.data.calib.dataset.column_names + warnings.warn( + f"Failed to trace {model_name} with dataset {column_names}. " + "Falling back to layer_sequential pipeline" + ) + try: - run_sequential( + run_layer_sequential( state.model, self.sequential_targets, - self.ignore, state.data.calib, propagate_error=True, ) + return True - except torch.fx.proxy.TraceError: - print(traceback.format_exc()) + # failure to match kwargs + except TypeError: warnings.warn( - "Failed to trace model graph, using non-sequential " - "pipeline with `offload_hessians = True`" + f"{model_name} does not conform to layer-wise architecture " + "assumptions. Falling back to basic pipeline, which requires extra " + "memory and may result in decreased accuracy" ) - self.offload_hessians = True run_basic(state.model, state.data.calib) - - else: - warnings.warn("Cannot use sequential pipeline with vision datasets") - run_basic(state.model, state.data.calib) - - return True + return True def on_finalize(self, state: "State", **kwargs) -> bool: """ From 4a046a54de75a3305cf2184c48b9e483c6414402 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 14 Dec 2024 16:22:14 -0500 Subject: [PATCH 203/285] update docstrings --- .../pipelines/layer_sequential/pipeline.py | 16 ++++++++++++- .../pipelines/sequential/pipeline.py | 24 ++++++++++--------- 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index aa5c57e2a..56f03204d 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -22,7 +22,21 @@ def run_pipeline( dataloader: torch.utils.data.DataLoader, propagate_error: bool, ): - """ """ + """ + Run a layer-wise sequential data pipeline. + 1. Layers are identified according to `sequential_targets` + 2. A hook is attached to the first layer. This hook raises an exception which is + then caught and used to capture the input arguments to the first layer + 3. The inputs to the first layer are used to calibrate the first layer, and the + output of the previous layer is used as inputs to calibrate the next layer + + This pipeline requires that the model have distinct layers defined in its + architecture and that the outputs of the previous layer are exactly the inputs + to the next layer. This is violated by encoder-decoder architectures among others. + + If your model architecture violates these assumptions, consider using the sequential + pipeline (see llmcompressor.pipelines.sequential) + """ # find layers layers = match_modules(model, sequential_targets) diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 6dbd4d58d..b1be72ddb 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -22,20 +22,22 @@ def run_pipeline( propagate_error: bool, ): """ - Run a sequential data pipeline. First, the model is partitioned into subgraphs - according to `sequential_targets`. Then, data passes through each subgraph - sequentially. If `propagate_error` is enabled, then data is passed through each - subgraph twice, once to trigger calibration hooks, then a second time in order to - capture activations after quantization has occurred through the hooks. - - In order to reduce memory requirements - 1. Data is passed through each subgraph with batch size 1 - 2. Intermediate activations between each subgraph are offloaded onto the CPU + Run a sequential data pipeline. + 1. The model is partitioned into subgraphs according to `sequential_targets` + 2. Data passes through each subgraph sequentially. If `propagate_error` is enabled, + then data is passed through each subgraph twice, once to trigger calibration + hooks, then a second time in order to capture activations after quantization + has occurred through the hooks. + 3. The intermediate activations between each subgraph are cached and offloaded to + the cpu between each batch in order to save memory This pipeline requires that the model be tracable with respect to data from the data loader. This may be an issue for vision language models with vision datasets, - due to specialized input processing in the model. In the event that tracing fails, - a torch.fx.proxy.TraceError will be raised. + due to specialized input processing in the model. + + In the event that tracing fails, a torch.fx.proxy.TraceError will be raised. A model + can be made tracable by wrapping the untracable functions (see + llmcompressor.transformers.tracing) """ # trace subgraphs sample_input = next(iter(dataloader)) From f24a2afcf025507dd5c7e12c757e2fda92821142 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 14 Dec 2024 16:24:13 -0500 Subject: [PATCH 204/285] fix last layer bug --- src/llmcompressor/pipelines/layer_sequential/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index 56f03204d..b659b93b9 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -66,8 +66,8 @@ def run_pipeline( for batch_index in tqdm.tqdm(range(len(dataloader)), desc=desc): inputs = intermediates.fetch(batch_index) output = layer(**inputs) - output = to_next_layer_kwargs(output, layers[layer_index + 1]) if layer_index < num_layers - 1: + output = to_next_layer_kwargs(output, layers[layer_index + 1]) intermediates.delete(batch_index) intermediates.update(batch_index, output) From 691bac40c4eebd80c9f36150c9cfe9a577cb42b7 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 14 Dec 2024 16:30:36 -0500 Subject: [PATCH 205/285] better inference --- .../modifiers/quantization/gptq/base.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index cb21d5e90..398b1d486 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -218,15 +218,13 @@ def on_initialize(self, state: "State", **kwargs) -> bool: ) return True - # failure to trace - except torch.fx.proxy.TraceError: - model_name = state.model.__class__.__name__ - column_names = state.data.calib.dataset.column_names - warnings.warn( - f"Failed to trace {model_name} with dataset {column_names}. " - "Falling back to layer_sequential pipeline" - ) + except Exception as exception: + if isinstance(exception, torch.fx.proxy.TraceError): + model_name = state.model.__class__.__name__ + column_names = state.data.calib.dataset.column_names + warnings.warn(f"Failed to trace {model_name} with {column_names}") + warnings.warn("Falling back to layer_sequential pipeline") try: run_layer_sequential( state.model, From 1e15d3e33229be164320d4a9d267f3c4d4a57059 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 14 Dec 2024 16:34:48 -0500 Subject: [PATCH 206/285] even better inference --- .../modifiers/quantization/gptq/base.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 398b1d486..c36209ce3 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -208,6 +208,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self._update_size = len(state.data.calib) # infer pipeline + model_name = state.model.__class__.__name__ + column_names = state.data.calib.dataset.column_names try: run_sequential( state.model, @@ -220,8 +222,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool: except Exception as exception: if isinstance(exception, torch.fx.proxy.TraceError): - model_name = state.model.__class__.__name__ - column_names = state.data.calib.dataset.column_names warnings.warn(f"Failed to trace {model_name} with {column_names}") warnings.warn("Falling back to layer_sequential pipeline") @@ -234,12 +234,13 @@ def on_initialize(self, state: "State", **kwargs) -> bool: ) return True - # failure to match kwargs - except TypeError: + except Exception as exception: + if isinstance(exception, TypeError): + warnings.warn(f"{model_name} fails layer-wise assumptions") + warnings.warn( - f"{model_name} does not conform to layer-wise architecture " - "assumptions. Falling back to basic pipeline, which requires extra " - "memory and may result in decreased accuracy" + "Falling back to basic pipeline, which requires extra memory and " + "may result in decreased accuracy" ) run_basic(state.model, state.data.calib) return True From a4744d93f3981695f6b10e6d6fbb8b927601d509 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 16 Dec 2024 14:44:50 -0500 Subject: [PATCH 207/285] do now throw warning for calibration with training Signed-off-by: Kyle Sayers --- .../finetune/data/data_helpers.py | 27 ++++++++++--------- .../finetune/data/test_dataset_helpers.py | 13 ++++++++- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index 9f14518f5..4968cd410 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -127,17 +127,18 @@ def make_dataset_splits( ) if do_eval: eval_split = _get_split_with_fallbacks( - datasets, "evaluation", ["validation", "test"], strict=True + datasets, "evaluation", ["validation"], ["test"], strict=True ) if do_predict: predict_split = _get_split_with_fallbacks( - datasets, "prediction", ["test", "validation"], strict=True + datasets, "prediction", ["test"], ["validation"], strict=True ) if do_oneshot: calib_split = _get_split_with_fallbacks( datasets, "oneshot", - ["calibration", "train", "test", "validation"], + ["calibration", "train"], + ["test", "validation"], strict=False, ) @@ -259,27 +260,27 @@ def do_transform(candidate: str) -> bool: def _get_split_with_fallbacks( datasets: Dict[str, DatasetType], task: str, - fallbacks: List[str], + preferred: List[str], + fallbacks: List[str] = [], strict: bool = True, ) -> DatasetType: - assert len(fallbacks) > 0 if len(datasets) <= 0: raise ValueError("Cannot get retrieve data from dataset with no splits") - # check first choice - first_choice = fallbacks[0] - if first_choice in datasets: - return datasets[first_choice] + # check preferred names (without warning) + for pref in preferred: + if pref in datasets: + return datasets[pref] - # last fallback is first available split + # fallback to the first available dataset if all else fails if not strict: fallbacks.append(next(iter(datasets.keys()))) - # check fallbacks - for fallback in fallbacks[1:]: + # check fallbacks (with warning) + for fallback in fallbacks: if fallback in datasets: warnings.warn( - f"{task} expects a {first_choice} dataset split, " + f"{task} expects one of {preferred} dataset split, " f"falling back to {fallback}" ) return datasets[fallback] diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py index 5229ea735..5e4f71d7b 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py @@ -72,7 +72,18 @@ def test_datasets_fallbacks(): assert split_ds.get("test").ds_name == "test_ds" assert split_ds.get("calibration").ds_name == "test_ds" - # oneshot will take any dataset + # oneshot takes train without warning + mock_datasets = {"train": Mock(ds_name="train_ds", column_names=[])} + split_ds = make_dataset_splits(mock_datasets, do_oneshot=True) + assert split_ds.get("calibration").ds_name == "train_ds" + + # oneshot takes test with warning + mock_datasets = {"test": Mock(ds_name="test_ds", column_names=[])} + with pytest.warns(UserWarning): + split_ds = make_dataset_splits(mock_datasets, do_oneshot=True) + assert split_ds.get("calibration").ds_name == "test_ds" + + # oneshot takes custom splits with warning mock_datasets = {"custom_split": Mock(ds_name="custom_ds", column_names=[])} with pytest.warns(UserWarning): split_ds = make_dataset_splits(mock_datasets, do_oneshot=True) From 9617e53bcd360a0cb6579352ffc0a5cf945611f4 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 16 Dec 2024 14:51:09 -0500 Subject: [PATCH 208/285] add information about how to silence warning Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/finetune/data/data_helpers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index 4968cd410..a0c7bb81d 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -264,6 +264,7 @@ def _get_split_with_fallbacks( fallbacks: List[str] = [], strict: bool = True, ) -> DatasetType: + assert len(preferred) > 0 if len(datasets) <= 0: raise ValueError("Cannot get retrieve data from dataset with no splits") @@ -281,7 +282,8 @@ def _get_split_with_fallbacks( if fallback in datasets: warnings.warn( f"{task} expects one of {preferred} dataset split, " - f"falling back to {fallback}" + f"falling back to {fallback}. Use " + f'splits={{"{preferred[0]}": "{fallback}"}} to silence this warning' ) return datasets[fallback] From 3b4cac11ff3e8add85ebfd812aadaade153c4436 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 16 Dec 2024 14:52:40 -0500 Subject: [PATCH 209/285] nice Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/finetune/data/data_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index a0c7bb81d..e8f30a5ff 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -283,7 +283,7 @@ def _get_split_with_fallbacks( warnings.warn( f"{task} expects one of {preferred} dataset split, " f"falling back to {fallback}. Use " - f'splits={{"{preferred[0]}": "{fallback}"}} to silence this warning' + f'`splits={{"{preferred[0]}": "{fallback}"}}` to silence this warning' ) return datasets[fallback] From f53a3dd857de69d8ba640973424181251bb94844 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 16 Dec 2024 15:02:27 -0500 Subject: [PATCH 210/285] remove unnecessary warning silencing Signed-off-by: Kyle Sayers --- tests/llmcompressor/transformers/obcq/test_obcq_sparsity.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_sparsity.py b/tests/llmcompressor/transformers/obcq/test_obcq_sparsity.py index db2634673..0ef7f872d 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_sparsity.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_sparsity.py @@ -12,7 +12,6 @@ @pytest.mark.integration -@pytest.mark.filterwarnings("ignore::UserWarning") @parameterized_class(parse_params(CONFIGS_DIRECTORY)) class TestSparsities(unittest.TestCase): model = None @@ -60,7 +59,6 @@ def tearDown(self): # TODO: @Satrat and @dsikka, revisit if we want these nightly or weekly @requires_gpu @pytest.mark.integration -@pytest.mark.filterwarnings("ignore::UserWarning") @parameterized_class(parse_params(GPU_CONFIGS_DIRECTORY)) class TestSparsitiesGPU(unittest.TestCase): model = None From fd151e4a59b538c057326f52c07268f01cf61778 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 16 Dec 2024 15:13:38 -0500 Subject: [PATCH 211/285] add unmerged thing --- src/llmcompressor/transformers/finetune/runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py index fc6276bd0..0841305d8 100644 --- a/src/llmcompressor/transformers/finetune/runner.py +++ b/src/llmcompressor/transformers/finetune/runner.py @@ -56,6 +56,7 @@ def __init__( self.datasets = {} self.trainer = None + self.processor = None self.parent_output_dir = self._training_args.output_dir self._output_dir = self._training_args.output_dir From d1d42deefc9408ada20b822a68c4a6d2bd93d9c0 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 16 Dec 2024 17:26:49 -0500 Subject: [PATCH 212/285] fix deleted columns --- .../transformers/finetune/data/base.py | 21 +++---------------- 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index ea3c71e7b..24126b374 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -272,10 +272,10 @@ def map( **kwargs, ) -> Union[Dataset, IterableDataset]: """ - Wrapper function around Dataset.map and IterableDataset.map + Wrapper function around Dataset.map and IterableDataset.map. - 1. Clears invalid parameters in the case where streaming is enabled - 2. Skips removing columns which were already removed after mapping + If the dataset is streaming (in the case of IterableDataset), non-applicable + arguments are ignored and the dataset features are resolved """ if isinstance(dataset, IterableDataset): # remove arguments that don't apply to streaming @@ -289,19 +289,4 @@ def map( if isinstance(dataset, IterableDataset): dataset = dataset._resolve_features() - # remove columns which are present, skip removing those which are not - if remove_columns is not None: - if isinstance(remove_columns, str): - remove_columns = [remove_columns] - - dataset_column_names = dataset.column_names - if isinstance(dataset_column_names, dict): - dataset_column_names = sum(dataset_column_names.values(), []) - if isinstance(remove_columns, dict): - remove_columns = sum(remove_columns.values(), []) - - dataset = dataset.remove_columns( - list(set(dataset_column_names) & set(remove_columns)) - ) - return dataset From 92151a11f25b6fcdfedaa4afe5e0800190dda781 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 16 Dec 2024 21:29:32 -0500 Subject: [PATCH 213/285] handle dataset dict case --- .../transformers/finetune/data/base.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index 24126b374..b69a7dfcd 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Union from compressed_tensors.registry import RegistryMixin from datasets import Dataset, IterableDataset @@ -93,7 +93,7 @@ def __call__(self, add_labels: bool = True) -> DatasetType: dataset, self.preprocess, batched=False, - remove_columns=dataset.column_names, + remove_columns=get_column_names(dataset), num_proc=self.data_args.preprocessing_num_workers, desc="Preprocessing", ) @@ -101,7 +101,7 @@ def __call__(self, add_labels: bool = True) -> DatasetType: # rename and remove columns match processor kwargs dataset = self.rename_columns(dataset) - if "input_ids" not in dataset.column_names: + if "input_ids" not in get_column_names(dataset): # tokenize/ process dataset = self.map( dataset, @@ -110,7 +110,7 @@ def __call__(self, add_labels: bool = True) -> DatasetType: keep_in_memory=True, # bug occurs when not batched and not in memory, # subsequent ds.map calls are always batched, # regardless of `batched` argument - remove_columns=dataset.column_names, + remove_columns=get_column_names(dataset), num_proc=self.data_args.preprocessing_num_workers, load_from_cache_file=not self.data_args.overwrite_cache, desc="Tokenizing", @@ -201,10 +201,8 @@ def dataset_template(self) -> Union[Callable[[Any], Any], None]: def rename_columns(self, dataset: DatasetType) -> DatasetType: # rename columns to match processor/tokenizer kwargs - if ( - self.data_args.text_column != "text" - and self.data_args.text_column in dataset.column_names - ): + column_names = get_column_names(dataset) + if self.data_args.text_column in column_names and "text" not in column_names: dataset = dataset.rename_column(self.data_args.text_column, "text") return dataset @@ -268,7 +266,6 @@ def map( self, dataset: Union[Dataset, IterableDataset], function: Callable[[Any], Any], - remove_columns: Optional[Union[str, List[str], Dict[str, List[str]]]] = None, **kwargs, ) -> Union[Dataset, IterableDataset]: """ @@ -290,3 +287,11 @@ def map( dataset = dataset._resolve_features() return dataset + + +def get_column_names(dataset: DatasetType) -> List[str]: + column_names = dataset.column_names + if isinstance(column_names, dict): + column_names = sum(column_names.values(), []) + + return column_names From 4c049db3cf7e95ec2f7bf57088ccf31c1a579978 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 16 Dec 2024 22:19:54 -0500 Subject: [PATCH 214/285] support torch.nn.Conv2d, silently ignore embeddings --- .../modifiers/quantization/gptq/base.py | 10 +++++++--- .../quantization/gptq/utils/gptq_quantize.py | 17 ++++++++++++++--- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index c36209ce3..daca8c274 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -194,8 +194,12 @@ def on_initialize(self, state: "State", **kwargs) -> bool: # register hooks for name, module in state.model.named_modules(): if getattr_chain(module, "quantization_scheme.weights", None) is not None: - post_hook = partial(self.compress_module, name) - self.register_hook(module, post_hook, "forward") + # HACK: previously, embeddings were not quantized because they were not + # accessible by the layer compressor. For now, we manually ignore it, + # but in the FUTURE this should be ignored by the user + if not isinstance(module, torch.nn.Embedding): + post_hook = partial(self.compress_module, name) + self.register_hook(module, post_hook, "forward") # infer sequential targets if self.sequential_targets is None: @@ -291,7 +295,7 @@ def compress_module( with self._maybe_onload_hessian(module): self._hessians[module], self._num_samples[module] = accumulate_hessian( inp, - type(module), + module, self._hessians[module], self._num_samples[module], ) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index d35fb9748..7efe21b94 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -1,6 +1,6 @@ import math from copy import copy -from typing import Dict, Optional, Tuple, Type, Union +from typing import Dict, Optional, Tuple, Union import torch import transformers @@ -31,7 +31,7 @@ def make_empty_hessian( def accumulate_hessian( inp: torch.Tensor, - module_class: Type[torch.nn.Module], + module: torch.nn.Module, H: Optional[torch.Tensor] = None, num_samples: int = 1, ) -> Tuple[torch.Tensor, int]: @@ -42,11 +42,22 @@ def accumulate_hessian( num_added = inp.shape[0] # note this is the number of dataset samples, not # multiplied by the sequence length - if module_class in (torch.nn.Linear, transformers.Conv1D): + if isinstance(module, (torch.nn.Linear, transformers.Conv1D)): if len(inp.shape) == 3: inp = inp.reshape((-1, inp.shape[-1])) inp = inp.t() + if isinstance(module, torch.nn.Conv2d): + unfold = torch.nn.Unfold( + module.kernel_size, + dilation=module.dilation, + padding=module.padding, + stride=module.stride, + ) + inp = unfold(inp) + inp = inp.permute([1, 0, 2]) + inp = inp.flatten(1) + H *= num_samples / (num_samples + num_added) num_samples += num_added From 7667998d4f8e54bb2dce7be6f75a0d1d7b294e0d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 17 Dec 2024 03:59:22 +0000 Subject: [PATCH 215/285] handle columns better Signed-off-by: Kyle Sayers --- .../transformers/finetune/data/base.py | 40 +++++++------------ 1 file changed, 15 insertions(+), 25 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index ea3c71e7b..b69a7dfcd 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Union from compressed_tensors.registry import RegistryMixin from datasets import Dataset, IterableDataset @@ -93,7 +93,7 @@ def __call__(self, add_labels: bool = True) -> DatasetType: dataset, self.preprocess, batched=False, - remove_columns=dataset.column_names, + remove_columns=get_column_names(dataset), num_proc=self.data_args.preprocessing_num_workers, desc="Preprocessing", ) @@ -101,7 +101,7 @@ def __call__(self, add_labels: bool = True) -> DatasetType: # rename and remove columns match processor kwargs dataset = self.rename_columns(dataset) - if "input_ids" not in dataset.column_names: + if "input_ids" not in get_column_names(dataset): # tokenize/ process dataset = self.map( dataset, @@ -110,7 +110,7 @@ def __call__(self, add_labels: bool = True) -> DatasetType: keep_in_memory=True, # bug occurs when not batched and not in memory, # subsequent ds.map calls are always batched, # regardless of `batched` argument - remove_columns=dataset.column_names, + remove_columns=get_column_names(dataset), num_proc=self.data_args.preprocessing_num_workers, load_from_cache_file=not self.data_args.overwrite_cache, desc="Tokenizing", @@ -201,10 +201,8 @@ def dataset_template(self) -> Union[Callable[[Any], Any], None]: def rename_columns(self, dataset: DatasetType) -> DatasetType: # rename columns to match processor/tokenizer kwargs - if ( - self.data_args.text_column != "text" - and self.data_args.text_column in dataset.column_names - ): + column_names = get_column_names(dataset) + if self.data_args.text_column in column_names and "text" not in column_names: dataset = dataset.rename_column(self.data_args.text_column, "text") return dataset @@ -268,14 +266,13 @@ def map( self, dataset: Union[Dataset, IterableDataset], function: Callable[[Any], Any], - remove_columns: Optional[Union[str, List[str], Dict[str, List[str]]]] = None, **kwargs, ) -> Union[Dataset, IterableDataset]: """ - Wrapper function around Dataset.map and IterableDataset.map + Wrapper function around Dataset.map and IterableDataset.map. - 1. Clears invalid parameters in the case where streaming is enabled - 2. Skips removing columns which were already removed after mapping + If the dataset is streaming (in the case of IterableDataset), non-applicable + arguments are ignored and the dataset features are resolved """ if isinstance(dataset, IterableDataset): # remove arguments that don't apply to streaming @@ -289,19 +286,12 @@ def map( if isinstance(dataset, IterableDataset): dataset = dataset._resolve_features() - # remove columns which are present, skip removing those which are not - if remove_columns is not None: - if isinstance(remove_columns, str): - remove_columns = [remove_columns] + return dataset - dataset_column_names = dataset.column_names - if isinstance(dataset_column_names, dict): - dataset_column_names = sum(dataset_column_names.values(), []) - if isinstance(remove_columns, dict): - remove_columns = sum(remove_columns.values(), []) - dataset = dataset.remove_columns( - list(set(dataset_column_names) & set(remove_columns)) - ) +def get_column_names(dataset: DatasetType) -> List[str]: + column_names = dataset.column_names + if isinstance(column_names, dict): + column_names = sum(column_names.values(), []) - return dataset + return column_names From f0eb640c8b949c3acaddacec30a439006034610d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 18 Dec 2024 16:32:34 +0000 Subject: [PATCH 216/285] fix tokenizer args Signed-off-by: Kyle Sayers --- .../transformers/finetune/data/base.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index b69a7dfcd..c7eb33259 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -1,4 +1,6 @@ +import inspect from functools import cached_property +from inspect import _ParameterKind as Kind from typing import Any, Callable, Dict, List, Union from compressed_tensors.registry import RegistryMixin @@ -93,7 +95,6 @@ def __call__(self, add_labels: bool = True) -> DatasetType: dataset, self.preprocess, batched=False, - remove_columns=get_column_names(dataset), num_proc=self.data_args.preprocessing_num_workers, desc="Preprocessing", ) @@ -103,6 +104,7 @@ def __call__(self, add_labels: bool = True) -> DatasetType: if "input_ids" not in get_column_names(dataset): # tokenize/ process + dataset = self.filter_tokenizer_args(dataset) dataset = self.map( dataset, self.tokenize, @@ -110,7 +112,8 @@ def __call__(self, add_labels: bool = True) -> DatasetType: keep_in_memory=True, # bug occurs when not batched and not in memory, # subsequent ds.map calls are always batched, # regardless of `batched` argument - remove_columns=get_column_names(dataset), + remove_columns=get_column_names(dataset), # assumes that input names + # and output names are disjoint num_proc=self.data_args.preprocessing_num_workers, load_from_cache_file=not self.data_args.overwrite_cache, desc="Tokenizing", @@ -138,7 +141,7 @@ def __call__(self, add_labels: bool = True) -> DatasetType: desc="Adding labels", ) - elif self.PROMPT_KEY in dataset.column_names: + elif self.PROMPT_KEY in get_column_names(dataset): dataset.remove_columns(self.PROMPT_KEY) return dataset @@ -207,6 +210,18 @@ def rename_columns(self, dataset: DatasetType) -> DatasetType: return dataset + def filter_tokenizer_args(self, dataset: DatasetType) -> DatasetType: + # assumes that inputs are not passed via self.processor.__call__ args and kwargs + signature = inspect.signature(self.processor.__call__) + tokenizer_args = set( + key + for key, param in signature.parameters.items() + if param.kind not in (Kind.VAR_POSITIONAL, Kind.VAR_KEYWORD) + ) + + column_names = get_column_names(dataset) + return dataset.remove_columns(list(set(column_names) - set(tokenizer_args))) + def tokenize(self, data: LazyRow) -> Dict[str, Any]: # separate prompt prompt = data.pop(self.PROMPT_KEY, None) From af86f45c479dfa1345f4394b6277dcd6bb43ec7e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 18 Dec 2024 17:31:33 +0000 Subject: [PATCH 217/285] filter_tokenizer_args Signed-off-by: Kyle Sayers --- .../transformers/finetune/data/base.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index b69a7dfcd..c7eb33259 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -1,4 +1,6 @@ +import inspect from functools import cached_property +from inspect import _ParameterKind as Kind from typing import Any, Callable, Dict, List, Union from compressed_tensors.registry import RegistryMixin @@ -93,7 +95,6 @@ def __call__(self, add_labels: bool = True) -> DatasetType: dataset, self.preprocess, batched=False, - remove_columns=get_column_names(dataset), num_proc=self.data_args.preprocessing_num_workers, desc="Preprocessing", ) @@ -103,6 +104,7 @@ def __call__(self, add_labels: bool = True) -> DatasetType: if "input_ids" not in get_column_names(dataset): # tokenize/ process + dataset = self.filter_tokenizer_args(dataset) dataset = self.map( dataset, self.tokenize, @@ -110,7 +112,8 @@ def __call__(self, add_labels: bool = True) -> DatasetType: keep_in_memory=True, # bug occurs when not batched and not in memory, # subsequent ds.map calls are always batched, # regardless of `batched` argument - remove_columns=get_column_names(dataset), + remove_columns=get_column_names(dataset), # assumes that input names + # and output names are disjoint num_proc=self.data_args.preprocessing_num_workers, load_from_cache_file=not self.data_args.overwrite_cache, desc="Tokenizing", @@ -138,7 +141,7 @@ def __call__(self, add_labels: bool = True) -> DatasetType: desc="Adding labels", ) - elif self.PROMPT_KEY in dataset.column_names: + elif self.PROMPT_KEY in get_column_names(dataset): dataset.remove_columns(self.PROMPT_KEY) return dataset @@ -207,6 +210,18 @@ def rename_columns(self, dataset: DatasetType) -> DatasetType: return dataset + def filter_tokenizer_args(self, dataset: DatasetType) -> DatasetType: + # assumes that inputs are not passed via self.processor.__call__ args and kwargs + signature = inspect.signature(self.processor.__call__) + tokenizer_args = set( + key + for key, param in signature.parameters.items() + if param.kind not in (Kind.VAR_POSITIONAL, Kind.VAR_KEYWORD) + ) + + column_names = get_column_names(dataset) + return dataset.remove_columns(list(set(column_names) - set(tokenizer_args))) + def tokenize(self, data: LazyRow) -> Dict[str, Any]: # separate prompt prompt = data.pop(self.PROMPT_KEY, None) From 9b611452add3eae6bf617d8aab5de9ffbc6b0ba7 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 18 Dec 2024 20:20:22 +0000 Subject: [PATCH 218/285] update docstring --- .../modifiers/quantization/gptq/base.py | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index daca8c274..3815134d0 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -43,9 +43,9 @@ class GPTQModifier(Modifier, HooksMixin): | test_stage: | obcq_modifiers: | GPTQModifier: - | true_sequential: False - | dampening_frac: 0.001 | block_size: 128 + | dampening_frac: 0.001 + | offload_hessians: False | config_groups: | group_0: | targets: @@ -63,25 +63,17 @@ class GPTQModifier(Modifier, HooksMixin): :param sequential_targets: list of layer names to compress during GPTQ, or '__ALL__' to compress every layer in the model - :param block_size: Used to determine number of columns to compress in one pass + :param dampening_frac: Amount of dampening to apply to H, as a fraction of the + diagonal norm :param quantize: Set to True to quantize using an existing quantization modifier, or pass in the configuration for a quantization modifier if one does not already exist in the recipe - :param dampening_frac: Amount of dampening to apply to H, as a fraction of the - diagonal norm + :param offload_hessians: Set to True for decreased memory usage but increased + runtime. :param config_groups: [Used, if a quantization modifier is not specified], dictionary specifying quantization schemes to apply to target modules. Modules not matching a scheme target will NOT be quantized. - :param ignore: [Used, if a quantization modifier is not specified] - optional list of module class names or submodule names to not - quantize even if they match a target in config_groups. Defaults to empty list. - :param disable_quantization_observer_epoch: [Used, if a quantization modifier is - not specified] Epoch to disable updates to the module - quantization observers. At this point, quantized weights and zero points will - not be updated. Leave None to not disable observers during QAT. Default is None - :param num_calibration_steps: Number of steps to run post training calibration for. - When None, the entire calibration_dataloader is used :param scheme: [Used, if a quantization modifier is not specified], the quantization scheme to apply to the model, this is a dictionary that supports all keys from QuantizationScheme except targets, which will be set to the targets parameter @@ -89,6 +81,17 @@ class GPTQModifier(Modifier, HooksMixin): `preset_scheme_name: targets` for example: `W8A8: ['Linear']` for weight 8 bit or a string of a preset scheme if targets is provided and activation 8 bit quantization on the Linear layers. + :param targets: list of layer names to quantize if a scheme is provided. Defaults + to Linear layers + :param ignore: [Used, if a quantization modifier is not specified] + optional list of module class names or submodule names to not + quantize even if they match a target in config_groups. Defaults to empty list. + :param num_calibration_steps: Number of steps to run post training calibration for. + When None, the entire calibration_dataloader is used + :param disable_quantization_observer_epoch: [Used, if a quantization modifier is + not specified] Epoch to disable updates to the module + quantization observers. At this point, quantized weights and zero points will + not be updated. Leave None to not disable observers during QAT. Default is None """ # gptq modifier arguments From 2f65d01a818f619dd9bc7c22c2aa90093402f5ce Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 18 Dec 2024 20:22:18 +0000 Subject: [PATCH 219/285] remove unused util Signed-off-by: Kyle Sayers --- .../finetune/data/data_helpers.py | 40 ------------------- 1 file changed, 40 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index 175d0a5f5..693f4e3dc 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -5,7 +5,6 @@ import torch from datasets import Dataset, load_dataset -from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from transformers.data import default_data_collator @@ -15,7 +14,6 @@ LABELS_MASK_VALUE = -100 __all__ = [ - "create_batch_dataloader", "format_calibration_data", "get_raw_dataset", "make_dataset_splits", @@ -24,44 +22,6 @@ ] -def create_batch_dataloader( - dataloader: torch.utils.data.DataLoader, - batch_size: int, -) -> torch.utils.data.DataLoader: - """ - Create a dataloader whose batch size is equal to the size of the dataset - - :param dataset: dataset used to generate dataloader - :param batch_size: batch size of new dataloader - :return: dataloader - """ - dataset = dataloader.dataset - sampler = dataloader.sampler.__class__(dataset) - - def pad_sequences(batch): - # extract input_ids and attention_mask from the batch - input_ids = [torch.tensor(item["input_ids"]).squeeze(0) for item in batch] - masks = [torch.tensor(item["attention_mask"]).squeeze(0) for item in batch] - - # while 0 is not necessarily the "correct" padding value, the padded - # input_ids are ignored according to the attention_mask - pad_input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) - pad_masks = pad_sequence(masks, batch_first=True, padding_value=0) - - return { - "input_ids": pad_input_ids, - "attention_mask": pad_masks, - } - - return torch.utils.data.DataLoader( - dataset, - batch_size=batch_size, - sampler=sampler, - collate_fn=pad_sequences, - pin_memory=True, - ) - - def format_calibration_data( tokenized_dataset: Dataset, num_calibration_samples: Optional[int] = None, From 338d1cb23651ec9d454b52208a7772591e0a89bb Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 18 Dec 2024 20:25:22 +0000 Subject: [PATCH 220/285] remove debug --- src/llmcompressor/transformers/finetune/data/data_helpers.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index 693f4e3dc..74aafce99 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -67,10 +67,6 @@ def format_calibration_data( if accelerator: calib_dataloader = accelerator.prepare(calib_dataloader) - # sample = next(iter(calib_dataloader)) - # print({k: [torch.tensor(s).shape for s in sample[k]] for k in sample}) - # breakpoint() - return calib_dataloader From f4fa9c37db6ac789b19928abe496551dbefc46d2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 18 Dec 2024 20:25:47 +0000 Subject: [PATCH 221/285] more tests Signed-off-by: Kyle Sayers --- .../finetune/data/test_dataset_helpers.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py index 5229ea735..5e4f71d7b 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py @@ -72,7 +72,18 @@ def test_datasets_fallbacks(): assert split_ds.get("test").ds_name == "test_ds" assert split_ds.get("calibration").ds_name == "test_ds" - # oneshot will take any dataset + # oneshot takes train without warning + mock_datasets = {"train": Mock(ds_name="train_ds", column_names=[])} + split_ds = make_dataset_splits(mock_datasets, do_oneshot=True) + assert split_ds.get("calibration").ds_name == "train_ds" + + # oneshot takes test with warning + mock_datasets = {"test": Mock(ds_name="test_ds", column_names=[])} + with pytest.warns(UserWarning): + split_ds = make_dataset_splits(mock_datasets, do_oneshot=True) + assert split_ds.get("calibration").ds_name == "test_ds" + + # oneshot takes custom splits with warning mock_datasets = {"custom_split": Mock(ds_name="custom_ds", column_names=[])} with pytest.warns(UserWarning): split_ds = make_dataset_splits(mock_datasets, do_oneshot=True) From e757e61705ec73d0b9a3e0ad043ba3874fe8e31d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 18 Dec 2024 20:28:59 +0000 Subject: [PATCH 222/285] remove duplicate file Signed-off-by: Kyle Sayers --- src/llmcompressor/utils/typing.py | 17 ----------------- 1 file changed, 17 deletions(-) delete mode 100644 src/llmcompressor/utils/typing.py diff --git a/src/llmcompressor/utils/typing.py b/src/llmcompressor/utils/typing.py deleted file mode 100644 index 1050f7138..000000000 --- a/src/llmcompressor/utils/typing.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Union - -from datasets import Dataset, DatasetDict, IterableDataset -from transformers import ( - BaseImageProcessor, - FeatureExtractionMixin, - PreTrainedTokenizer, - ProcessorMixin, -) - -# Tokenizer or Processor. Processors do not inherit from a unified base class -Processor = Union[ - PreTrainedTokenizer, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin -] - -# Supported dataset types, IterableDataset is a streamed dataset -DatasetType = Union[Dataset, DatasetDict, IterableDataset] From bdfa3d4e39847ae4f5b3a4b5f50037c0a5c6e008 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 18 Dec 2024 21:09:20 +0000 Subject: [PATCH 223/285] better help texts Signed-off-by: Kyle Sayers --- .../transformers/finetune/data/data_args.py | 16 +++++++++------- .../transformers/finetune/text_generation.py | 4 ++-- .../finetune/data/test_dataset_loading.py | 1 - 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/data_args.py b/src/llmcompressor/transformers/finetune/data/data_args.py index a02975ef8..d0c097978 100644 --- a/src/llmcompressor/transformers/finetune/data/data_args.py +++ b/src/llmcompressor/transformers/finetune/data/data_args.py @@ -35,23 +35,25 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments): text_column: str = field( default="text", - metadata={"help": "For custom datasets only. The text field key"}, + metadata={ + "help": ( + "Optional key to be used as the `text` input to tokenizer/processor " + "after dataset preprocesssing" + ) + }, ) remove_columns: Union[None, str, List] = field( default=None, - metadata={ - "help": "This argument is depreciated. Column names to remove after " - "preprocessing custom datasets" - }, + metadata={"help": "Column names to remove after preprocessing (depreciated)"}, ) preprocessing_func: Union[None, str, Callable] = field( default=None, metadata={ "help": ( - "For custom datasets only. Either a function to apply to the dataset, " - "a function name defined in " + "Typically a function which applies a chat template. Can take the form " + "of iither a function to apply to the dataset, a name defined in " "src/llmcompressor/transformers/utils/preprocessing_functions.py, or " "a path to a function definition of the form /path/to/file.py:func" ) diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index 1b0de6870..2c596f557 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -141,8 +141,8 @@ def parse_args(**kwargs): # raise depreciation warnings if data_args.remove_columns is not None: warnings.warn( - "`remove_columns` argument is depreciated, when processing non-tokenized " - "datasets, all columns not returned by preprocessing_fn will be removed", + "`remove_columns` argument is depreciated. When tokenizing datasets, all " + "columns which are invalid inputs the tokenizer will be removed", DeprecationWarning, ) diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py index ef60cb811..64514b252 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py @@ -15,7 +15,6 @@ format_calibration_data, ) from llmcompressor.transformers.finetune.runner import StageRunner -from llmcompressor.transformers.finetune.training_args import TrainingArguments @pytest.mark.unit From f1e133545d1c64e1fd73894edc7881072c67edec Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 19 Dec 2024 20:08:09 +0000 Subject: [PATCH 224/285] remove future notes, todos --- src/llmcompressor/pipelines/basic/pipeline.py | 15 ++++----------- .../pipelines/layer_sequential/pipeline.py | 5 +---- .../pipelines/sequential/pipeline.py | 5 +---- 3 files changed, 6 insertions(+), 19 deletions(-) diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index 31f4c01be..bb2e9a352 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -5,21 +5,14 @@ from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch from llmcompressor.pytorch.utils.helpers import tensors_to_device from llmcompressor.utils.helpers import calibration_forward_context +from compressed_tensors.utils import get_execution_device __all__ = ["run_pipeline"] -def run_pipeline( - model: torch.nn.Module, - dataloader: torch.utils.data.DataLoader, -): - # TODO: revisit - device_map = getattr(model, "hf_device_map", None) - if device_map is not None: - model_device = next(iter(device_map.values())) - else: - model_device = model.device - +def run_pipeline(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader): + model_device = get_execution_device(model) + with calibration_forward_context(model): for batch in tqdm.tqdm(dataloader, desc="Calibrating"): batch = apply_pad_mask_to_batch(batch) diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index b659b93b9..334269e83 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -18,7 +18,7 @@ def run_pipeline( model: torch.nn.Module, - sequential_targets: List[str], # FUTURE: replace with recipe inference + sequential_targets: List[str], dataloader: torch.utils.data.DataLoader, propagate_error: bool, ): @@ -40,9 +40,6 @@ def run_pipeline( # find layers layers = match_modules(model, sequential_targets) - # FUTURE: apply recipe to model - # initialize(recipe, model) - with calibration_forward_context(model): intermediates = capture_first_layer_intermediates(model, layers, dataloader) diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index b1be72ddb..629c0aeb5 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -16,7 +16,7 @@ def run_pipeline( model: torch.nn.Module, - sequential_targets: List[str], # FUTURE: replace with recipe inference + sequential_targets: List[str], ignore: List[str], dataloader: torch.utils.data.DataLoader, propagate_error: bool, @@ -43,9 +43,6 @@ def run_pipeline( sample_input = next(iter(dataloader)) subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore) - # FUTURE: apply recipe to model - # initialize(recipe, model) - with calibration_forward_context(model): # prepare intermediates cache model_device = get_execution_device(model) From e59c2e75c01c555d257786c4dd411c764106ed11 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 19 Dec 2024 20:16:51 +0000 Subject: [PATCH 225/285] remove skipping patching --- .../transformers/sparsification/compressed_tensors_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index 766b7e584..ce4ae7fb2 100644 --- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -229,7 +229,6 @@ def patch_tied_tensors_bug(model: torch.nn.Module): :param model: model to fix """ - return if ( hasattr(model.config, "tie_word_embeddings") and not model.config.tie_word_embeddings From 4932ec592e779831390f7764960ebdc2c2e448d9 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 19 Dec 2024 20:21:36 +0000 Subject: [PATCH 226/285] remove skipping for none args --- src/llmcompressor/modifiers/utils/pytorch_helpers.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/modifiers/utils/pytorch_helpers.py b/src/llmcompressor/modifiers/utils/pytorch_helpers.py index 7de7dc58d..b3ec829e8 100644 --- a/src/llmcompressor/modifiers/utils/pytorch_helpers.py +++ b/src/llmcompressor/modifiers/utils/pytorch_helpers.py @@ -1,5 +1,5 @@ from itertools import cycle -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple, Any import torch from torch.nn import Module @@ -25,9 +25,7 @@ class EarlyStopException(Exception): :param kwargs: keyword inputs passed to the layer where the excetion was raised """ - def __init__(self, args: Tuple, kwargs: Dict): - if args is None: - return + def __init__(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]): self.args = tensors_to_device(args, "cpu") self.kwargs = kwargs From 6b7c11f952f9c010afe5965ec94f59b77e703809 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 19 Dec 2024 21:26:04 +0000 Subject: [PATCH 227/285] revert data split fallbacks --- .../finetune/data/data_helpers.py | 90 +++++-------------- .../finetune/data/test_dataset_helpers.py | 41 +-------- 2 files changed, 25 insertions(+), 106 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index 74aafce99..23c70e561 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -1,6 +1,5 @@ import logging import os -import warnings from typing import Any, Callable, Dict, List, Optional import torch @@ -8,8 +7,6 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from transformers.data import default_data_collator -from llmcompressor.typing import DatasetType - LOGGER = logging.getLogger(__name__) LABELS_MASK_VALUE = -100 @@ -18,7 +15,6 @@ "get_raw_dataset", "make_dataset_splits", "get_custom_datasets_from_path", - "LABELS_MASK_VALUE", ] @@ -28,7 +24,7 @@ def format_calibration_data( do_shuffle: bool = True, collate_fn: Callable = default_data_collator, accelerator: Optional[Any] = None, -) -> torch.utils.data.DataLoader: +) -> List[torch.Tensor]: """ Creates a dataloader out of the calibration dataset split, trimming it to the desired number of calibration samples @@ -96,17 +92,17 @@ def get_raw_dataset( def make_dataset_splits( - datasets: Dict[str, DatasetType], + tokenized_datasets: Dict[str, Any], do_train: bool = False, do_eval: bool = False, do_predict: bool = False, do_oneshot: bool = False, -) -> Dict[str, DatasetType]: +) -> Dict[str, Dataset]: """ Restructures the datasets dictionary based on what tasks will be run (train, eval, predict) - :param datasets: dictionary of processed datasets + :param tokenized_datasets: dictionary of processed datasets :param do_train: Whether to store the train dataset :param do_eval: Whether to store the validation dataset :param do_predict: Whether to store the test dataset @@ -115,40 +111,31 @@ def make_dataset_splits( """ # handles case where all splits are contained in a single dataset - if "all" in datasets and len(datasets) == 1: - datasets = datasets.get("all") - if isinstance(datasets, Dataset): - datasets = {"train": datasets, "calibration": datasets} # shallow copy + if "all" in tokenized_datasets and len(tokenized_datasets) == 1: + tokenized_datasets = tokenized_datasets.get("all") + if isinstance(tokenized_datasets, Dataset): + tokenized_datasets = {"train": tokenized_datasets} train_split = eval_split = predict_split = calib_split = None if do_train: - train_split = _get_split_with_fallbacks( - datasets, "train", ["train"], strict=True - ) + if "train" not in tokenized_datasets: + raise ValueError("--do_train requires a train dataset") + train_split = tokenized_datasets["train"] if do_eval: - eval_split = _get_split_with_fallbacks( - datasets, "evaluation", ["validation"], ["test"], strict=True - ) + if "validation" not in tokenized_datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_split = tokenized_datasets["validation"] if do_predict: - predict_split = _get_split_with_fallbacks( - datasets, "prediction", ["test"], ["validation"], strict=True - ) + if "test" not in tokenized_datasets: + raise ValueError("--do_predict requires a test dataset") + predict_split = tokenized_datasets["test"] if do_oneshot: - calib_split = _get_split_with_fallbacks( - datasets, - "oneshot", - ["calibration", "train"], - ["test", "validation"], - strict=False, - ) - - # remove labels from calibration dataset - column_names = calib_split.column_names - if isinstance(column_names, dict): - column_names = sum(column_names.values(), []) - if "labels" in column_names: - calib_split = calib_split.remove_columns("labels") + calib_split = tokenized_datasets.get("calibration") + if calib_split is None: + if "train" not in tokenized_datasets: + raise ValueError("--do_oneshot requires a calibration dataset") + calib_split = tokenized_datasets["train"] split_datasets = { "train": train_split, @@ -256,36 +243,3 @@ def do_transform(candidate: str) -> bool: transform_dataset_key(dataset_key) return data_files - - -def _get_split_with_fallbacks( - datasets: Dict[str, DatasetType], - task: str, - preferred: List[str], - fallbacks: List[str] = [], - strict: bool = True, -) -> DatasetType: - assert len(preferred) > 0 - if len(datasets) <= 0: - raise ValueError("Cannot get retrieve data from dataset with no splits") - - # check preferred names (without warning) - for pref in preferred: - if pref in datasets: - return datasets[pref] - - # fallback to the first available dataset if all else fails - if not strict: - fallbacks.append(next(iter(datasets.keys()))) - - # check fallbacks (with warning) - for fallback in fallbacks: - if fallback in datasets: - warnings.warn( - f"{task} expects one of {preferred} dataset split, " - f"falling back to {fallback}. Use " - f'`splits={{"{preferred[0]}": "{fallback}"}}` to silence this warning' - ) - return datasets[fallback] - - raise ValueError(f"{task} expects at least one of {fallbacks} dataset splits") diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py index 5e4f71d7b..812b26a56 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py @@ -1,5 +1,3 @@ -from unittest.mock import Mock - import pytest from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments @@ -50,41 +48,8 @@ def test_separate_datasets(): assert split_datasets.get("validation") is not None assert split_datasets.get("test") is None - -@pytest.mark.unit -def test_datasets_fallbacks(): - # strict splits - mock_datasets = {"calibration": Mock(ds_name="calibration_ds", column_names=[])} - with pytest.raises(ValueError): - _ = make_dataset_splits(mock_datasets, do_train=True) - with pytest.raises(ValueError): - _ = make_dataset_splits(mock_datasets, do_eval=True) with pytest.raises(ValueError): - _ = make_dataset_splits(mock_datasets, do_predict=True) - - # validation, predict, and oneshot fallbacks - mock_datasets = {"test": Mock(ds_name="test_ds", column_names=[])} - with pytest.warns(UserWarning): - split_ds = make_dataset_splits( - mock_datasets, do_eval=True, do_predict=True, do_oneshot=True + # fails due to no test split specified + split_datasets = make_dataset_splits( + datasets, do_train=True, do_eval=True, do_predict=True ) - assert split_ds.get("validation").ds_name == "test_ds" - assert split_ds.get("test").ds_name == "test_ds" - assert split_ds.get("calibration").ds_name == "test_ds" - - # oneshot takes train without warning - mock_datasets = {"train": Mock(ds_name="train_ds", column_names=[])} - split_ds = make_dataset_splits(mock_datasets, do_oneshot=True) - assert split_ds.get("calibration").ds_name == "train_ds" - - # oneshot takes test with warning - mock_datasets = {"test": Mock(ds_name="test_ds", column_names=[])} - with pytest.warns(UserWarning): - split_ds = make_dataset_splits(mock_datasets, do_oneshot=True) - assert split_ds.get("calibration").ds_name == "test_ds" - - # oneshot takes custom splits with warning - mock_datasets = {"custom_split": Mock(ds_name="custom_ds", column_names=[])} - with pytest.warns(UserWarning): - split_ds = make_dataset_splits(mock_datasets, do_oneshot=True) - assert split_ds.get("calibration").ds_name == "custom_ds" From 601cb0e07fbf1581bddba75d565d51d3a74f6811 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 19 Dec 2024 21:26:46 +0000 Subject: [PATCH 228/285] rvert data split fallbacks --- .../finetune/data/data_helpers.py | 84 +++++-------------- .../finetune/data/test_dataset_helpers.py | 41 +-------- 2 files changed, 24 insertions(+), 101 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index 148cf85af..23c70e561 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -1,6 +1,5 @@ import logging import os -import warnings from typing import Any, Callable, Dict, List, Optional import torch @@ -8,8 +7,6 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from transformers.data import default_data_collator -from llmcompressor.typing import DatasetType - LOGGER = logging.getLogger(__name__) LABELS_MASK_VALUE = -100 @@ -95,17 +92,17 @@ def get_raw_dataset( def make_dataset_splits( - datasets: Dict[str, DatasetType], + tokenized_datasets: Dict[str, Any], do_train: bool = False, do_eval: bool = False, do_predict: bool = False, do_oneshot: bool = False, -) -> Dict[str, DatasetType]: +) -> Dict[str, Dataset]: """ Restructures the datasets dictionary based on what tasks will be run (train, eval, predict) - :param datasets: dictionary of processed datasets + :param tokenized_datasets: dictionary of processed datasets :param do_train: Whether to store the train dataset :param do_eval: Whether to store the validation dataset :param do_predict: Whether to store the test dataset @@ -114,39 +111,31 @@ def make_dataset_splits( """ # handles case where all splits are contained in a single dataset - if "all" in datasets and len(datasets) == 1: - datasets = datasets.get("all") - if isinstance(datasets, Dataset): - datasets = {"train": datasets} + if "all" in tokenized_datasets and len(tokenized_datasets) == 1: + tokenized_datasets = tokenized_datasets.get("all") + if isinstance(tokenized_datasets, Dataset): + tokenized_datasets = {"train": tokenized_datasets} train_split = eval_split = predict_split = calib_split = None if do_train: - train_split = _get_split_with_fallbacks( - datasets, "train", ["train"], strict=True - ) + if "train" not in tokenized_datasets: + raise ValueError("--do_train requires a train dataset") + train_split = tokenized_datasets["train"] if do_eval: - eval_split = _get_split_with_fallbacks( - datasets, "evaluation", ["validation", "test"], strict=True - ) + if "validation" not in tokenized_datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_split = tokenized_datasets["validation"] if do_predict: - predict_split = _get_split_with_fallbacks( - datasets, "prediction", ["test", "validation"], strict=True - ) + if "test" not in tokenized_datasets: + raise ValueError("--do_predict requires a test dataset") + predict_split = tokenized_datasets["test"] if do_oneshot: - calib_split = _get_split_with_fallbacks( - datasets, - "oneshot", - ["calibration", "train", "test", "validation"], - strict=False, - ) - - # remove labels from calibration dataset - column_names = calib_split.column_names - if isinstance(column_names, dict): - column_names = sum(column_names.values(), []) - if "labels" in column_names: - calib_split = calib_split.remove_columns("labels") + calib_split = tokenized_datasets.get("calibration") + if calib_split is None: + if "train" not in tokenized_datasets: + raise ValueError("--do_oneshot requires a calibration dataset") + calib_split = tokenized_datasets["train"] split_datasets = { "train": train_split, @@ -254,34 +243,3 @@ def do_transform(candidate: str) -> bool: transform_dataset_key(dataset_key) return data_files - - -def _get_split_with_fallbacks( - datasets: Dict[str, DatasetType], - task: str, - fallbacks: List[str], - strict: bool = True, -) -> DatasetType: - assert len(fallbacks) > 0 - if len(datasets) <= 0: - raise ValueError("Cannot get retrieve data from dataset with no splits") - - # check first choice - first_choice = fallbacks[0] - if first_choice in datasets: - return datasets[first_choice] - - # last fallback is first available split - if not strict: - fallbacks.append(next(iter(datasets.keys()))) - - # check fallbacks - for fallback in fallbacks[1:]: - if fallback in datasets: - warnings.warn( - f"{task} expects a {first_choice} dataset split, " - f"falling back to {fallback}" - ) - return datasets[fallback] - - raise ValueError(f"{task} expects at least one of {fallbacks} dataset splits") diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py index 5e4f71d7b..812b26a56 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py @@ -1,5 +1,3 @@ -from unittest.mock import Mock - import pytest from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments @@ -50,41 +48,8 @@ def test_separate_datasets(): assert split_datasets.get("validation") is not None assert split_datasets.get("test") is None - -@pytest.mark.unit -def test_datasets_fallbacks(): - # strict splits - mock_datasets = {"calibration": Mock(ds_name="calibration_ds", column_names=[])} - with pytest.raises(ValueError): - _ = make_dataset_splits(mock_datasets, do_train=True) - with pytest.raises(ValueError): - _ = make_dataset_splits(mock_datasets, do_eval=True) with pytest.raises(ValueError): - _ = make_dataset_splits(mock_datasets, do_predict=True) - - # validation, predict, and oneshot fallbacks - mock_datasets = {"test": Mock(ds_name="test_ds", column_names=[])} - with pytest.warns(UserWarning): - split_ds = make_dataset_splits( - mock_datasets, do_eval=True, do_predict=True, do_oneshot=True + # fails due to no test split specified + split_datasets = make_dataset_splits( + datasets, do_train=True, do_eval=True, do_predict=True ) - assert split_ds.get("validation").ds_name == "test_ds" - assert split_ds.get("test").ds_name == "test_ds" - assert split_ds.get("calibration").ds_name == "test_ds" - - # oneshot takes train without warning - mock_datasets = {"train": Mock(ds_name="train_ds", column_names=[])} - split_ds = make_dataset_splits(mock_datasets, do_oneshot=True) - assert split_ds.get("calibration").ds_name == "train_ds" - - # oneshot takes test with warning - mock_datasets = {"test": Mock(ds_name="test_ds", column_names=[])} - with pytest.warns(UserWarning): - split_ds = make_dataset_splits(mock_datasets, do_oneshot=True) - assert split_ds.get("calibration").ds_name == "test_ds" - - # oneshot takes custom splits with warning - mock_datasets = {"custom_split": Mock(ds_name="custom_ds", column_names=[])} - with pytest.warns(UserWarning): - split_ds = make_dataset_splits(mock_datasets, do_oneshot=True) - assert split_ds.get("calibration").ds_name == "custom_ds" From 4123636b836ce9e93c1f3e1be5d6afe9c5f0ade9 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 19 Dec 2024 21:50:41 +0000 Subject: [PATCH 229/285] propagate oom errors, separate data collators Signed-off-by: Kyle Sayers --- examples/multimodal_vision/mllama.py | 21 ++----------- examples/multimodal_vision/pixtral.py | 17 ++-------- examples/multimodal_vision/pixtral_large.py | 6 ---- .../{qwen_vl2.py => qwen2_vl.py} | 21 +++---------- .../modifiers/quantization/gptq/base.py | 8 +++-- .../transformers/utils/data_collator.py | 31 +++++++++++++++++++ 6 files changed, 47 insertions(+), 57 deletions(-) delete mode 100644 examples/multimodal_vision/pixtral_large.py rename examples/multimodal_vision/{qwen_vl2.py => qwen2_vl.py} (79%) create mode 100644 src/llmcompressor/transformers/utils/data_collator.py diff --git a/examples/multimodal_vision/mllama.py b/examples/multimodal_vision/mllama.py index 3f61532e2..6a6bf0d7e 100644 --- a/examples/multimodal_vision/mllama.py +++ b/examples/multimodal_vision/mllama.py @@ -5,9 +5,9 @@ from llmcompressor.modifiers.quantization import GPTQModifier -# from llmcompressor.pytorch.data_collator import DataCollator from llmcompressor.transformers import oneshot from llmcompressor.transformers.tracing import TracableMllamaForConditionalGeneration +from llmcompressor.transformers.utils.data_collator import mllama_data_collator # Load model. model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" @@ -18,24 +18,10 @@ # Oneshot arguments DATASET_ID = "flickr30k" -DATASET_SPLIT = "test[:512]" +DATASET_SPLIT = {"calibration": "test[:512]"} NUM_CALIBRATION_SAMPLES = 512 MAX_SEQUENCE_LENGTH = 2048 - -# TODO: define real collators in utils -def data_collator(batch): - assert len(batch) == 1 - return { - "input_ids": torch.LongTensor(batch[0]["input_ids"]), - "attention_mask": torch.tensor(batch[0]["attention_mask"]), - "pixel_values": torch.tensor(batch[0]["pixel_values"]), - "aspect_ratio_ids": torch.tensor(batch[0]["aspect_ratio_ids"]), - "aspect_ratio_mask": torch.tensor(batch[0]["aspect_ratio_mask"]), - "cross_attention_mask": torch.tensor(batch[0]["cross_attention_mask"]), - } - - # Recipe recipe = [ # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), @@ -60,8 +46,7 @@ def data_collator(batch): num_calibration_samples=NUM_CALIBRATION_SAMPLES, trust_remote_code_model=True, output_dir=save_path, - data_collator=data_collator, - # data_collator=DataCollator(), + data_collator=mllama_data_collator, ) # Confirm generations of the quantized model look sane. diff --git a/examples/multimodal_vision/pixtral.py b/examples/multimodal_vision/pixtral.py index f311582a6..34890dee3 100644 --- a/examples/multimodal_vision/pixtral.py +++ b/examples/multimodal_vision/pixtral.py @@ -8,6 +8,7 @@ # from llmcompressor.pytorch.data_collator import DataCollator from llmcompressor.transformers import oneshot from llmcompressor.transformers.tracing import TracableLlavaForConditionalGeneration +from llmcompressor.transformers.utils.data_collator import pixtral_data_collator # Load model. model_id = "mgoin/pixtral-12b" @@ -18,21 +19,10 @@ # Oneshot arguments DATASET_ID = "flickr30k" -DATASET_SPLIT = "test[:512]" +DATASET_SPLIT = {"calibration": "test[:512]"} NUM_CALIBRATION_SAMPLES = 512 MAX_SEQUENCE_LENGTH = 2048 - -# TODO: define real collators in utils -def data_collator(batch): - assert len(batch) == 1 - return { - "input_ids": torch.LongTensor(batch[0]["input_ids"]), - "attention_mask": torch.tensor(batch[0]["attention_mask"]), - "pixel_values": torch.tensor(batch[0]["pixel_values"])[0], - } - - # Recipe recipe = [ # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), @@ -58,8 +48,7 @@ def data_collator(batch): num_calibration_samples=NUM_CALIBRATION_SAMPLES, trust_remote_code_model=True, output_dir=save_path, - data_collator=data_collator, - # data_collator=DataCollator(), + data_collator=pixtral_data_collator, ) # Confirm generations of the quantized model look sane. diff --git a/examples/multimodal_vision/pixtral_large.py b/examples/multimodal_vision/pixtral_large.py deleted file mode 100644 index ebef5047d..000000000 --- a/examples/multimodal_vision/pixtral_large.py +++ /dev/null @@ -1,6 +0,0 @@ -from transformers import AutoProcessor - -processor = AutoProcessor.from_pretrained( - "mistral-community/Pixtral-Large-Instruct-2411" -) -processor = AutoProcessor.from_pretrained("mgoin/pixtral-12b") diff --git a/examples/multimodal_vision/qwen_vl2.py b/examples/multimodal_vision/qwen2_vl.py similarity index 79% rename from examples/multimodal_vision/qwen_vl2.py rename to examples/multimodal_vision/qwen2_vl.py index 5e470bdc3..b57de913e 100644 --- a/examples/multimodal_vision/qwen_vl2.py +++ b/examples/multimodal_vision/qwen2_vl.py @@ -11,6 +11,7 @@ from llmcompressor.modifiers.quantization import GPTQModifier from llmcompressor.transformers import oneshot +from llmcompressor.transformers.utils.data_collator import qwen2_vl_data_collator # Load model. model_id = "Qwen/Qwen2-VL-2B-Instruct" @@ -21,24 +22,10 @@ # Oneshot arguments DATASET_ID = "flickr30k" -DATASET_SPLIT = "test[:3]" -NUM_CALIBRATION_SAMPLES = 1 +DATASET_SPLIT = {"calibration": "test[:512]"} +NUM_CALIBRATION_SAMPLES = 512 MAX_SEQUENCE_LENGTH = 2048 - -# TODO: define real collators in utils -def data_collator(batch): - assert len(batch) == 1 - return { - "input_ids": torch.LongTensor(batch[0]["input_ids"]), - "attention_mask": torch.tensor(batch[0]["attention_mask"]), - "pixel_values": torch.tensor( - batch[0]["pixel_values"] - ), # torch.Size([14308, 1176]) - "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]), - } - - # Recipe recipe = GPTQModifier( targets="Linear", @@ -75,7 +62,7 @@ def data_collator(batch): num_calibration_samples=NUM_CALIBRATION_SAMPLES, trust_remote_code_model=True, output_dir=save_path, - data_collator=data_collator, + data_collator=qwen2_vl_data_collator, ) processor.save_pretrained(save_path) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 3815134d0..d01584aa6 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -216,7 +216,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: # infer pipeline model_name = state.model.__class__.__name__ - column_names = state.data.calib.dataset.column_names + input_names = state.data.calib.dataset.column_names try: run_sequential( state.model, @@ -229,7 +229,9 @@ def on_initialize(self, state: "State", **kwargs) -> bool: except Exception as exception: if isinstance(exception, torch.fx.proxy.TraceError): - warnings.warn(f"Failed to trace {model_name} with {column_names}") + warnings.warn(f"Failed to trace {model_name} with inputs {input_names}") + if isinstance(exception, torch.OutOfMemoryError): + raise exception warnings.warn("Falling back to layer_sequential pipeline") try: @@ -244,6 +246,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool: except Exception as exception: if isinstance(exception, TypeError): warnings.warn(f"{model_name} fails layer-wise assumptions") + if isinstance(exception, torch.OutOfMemoryError): + raise exception warnings.warn( "Falling back to basic pipeline, which requires extra memory and " diff --git a/src/llmcompressor/transformers/utils/data_collator.py b/src/llmcompressor/transformers/utils/data_collator.py new file mode 100644 index 000000000..058ce8af1 --- /dev/null +++ b/src/llmcompressor/transformers/utils/data_collator.py @@ -0,0 +1,31 @@ +import torch + +__all__ = ["mllama_data_collator", "pixtral_data_collator"] + +def mllama_data_collator(batch): + assert len(batch) == 1 + return { + "input_ids": torch.LongTensor(batch[0]["input_ids"]), + "attention_mask": torch.tensor(batch[0]["attention_mask"]), + "pixel_values": torch.tensor(batch[0]["pixel_values"]), + "aspect_ratio_ids": torch.tensor(batch[0]["aspect_ratio_ids"]), + "aspect_ratio_mask": torch.tensor(batch[0]["aspect_ratio_mask"]), + "cross_attention_mask": torch.tensor(batch[0]["cross_attention_mask"]), + } + +def pixtral_data_collator(batch): + assert len(batch) == 1 + return { + "input_ids": torch.LongTensor(batch[0]["input_ids"]), + "attention_mask": torch.tensor(batch[0]["attention_mask"]), + "pixel_values": torch.tensor(batch[0]["pixel_values"])[0], + } + +def qwen2_vl_data_collator(batch): + assert len(batch) == 1 + return { + "input_ids": torch.LongTensor(batch[0]["input_ids"]), + "attention_mask": torch.tensor(batch[0]["attention_mask"]), + "pixel_values": torch.tensor(batch[0]["pixel_values"]), + "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]), + } \ No newline at end of file From c1e66e885d9eaa65f12feb529c2d35800584f3ad Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 19 Dec 2024 21:54:06 +0000 Subject: [PATCH 230/285] apply style, ignore visual on qwen --- examples/multimodal_vision/mllama.py | 2 -- examples/multimodal_vision/pixtral.py | 1 - examples/multimodal_vision/qwen2_vl.py | 3 +-- src/llmcompressor/modifiers/utils/pytorch_helpers.py | 2 +- src/llmcompressor/pipelines/basic/pipeline.py | 4 ++-- src/llmcompressor/transformers/utils/data_collator.py | 5 ++++- 6 files changed, 8 insertions(+), 9 deletions(-) diff --git a/examples/multimodal_vision/mllama.py b/examples/multimodal_vision/mllama.py index 6a6bf0d7e..b9241f53a 100644 --- a/examples/multimodal_vision/mllama.py +++ b/examples/multimodal_vision/mllama.py @@ -1,10 +1,8 @@ import os -import torch from transformers import AutoProcessor from llmcompressor.modifiers.quantization import GPTQModifier - from llmcompressor.transformers import oneshot from llmcompressor.transformers.tracing import TracableMllamaForConditionalGeneration from llmcompressor.transformers.utils.data_collator import mllama_data_collator diff --git a/examples/multimodal_vision/pixtral.py b/examples/multimodal_vision/pixtral.py index 34890dee3..4687e150b 100644 --- a/examples/multimodal_vision/pixtral.py +++ b/examples/multimodal_vision/pixtral.py @@ -1,6 +1,5 @@ import os -import torch from transformers import AutoProcessor from llmcompressor.modifiers.quantization import GPTQModifier diff --git a/examples/multimodal_vision/qwen2_vl.py b/examples/multimodal_vision/qwen2_vl.py index b57de913e..2b27e81b3 100644 --- a/examples/multimodal_vision/qwen2_vl.py +++ b/examples/multimodal_vision/qwen2_vl.py @@ -1,6 +1,5 @@ import os -import torch from compressed_tensors.quantization import ( QuantizationArgs, QuantizationScheme, @@ -43,7 +42,7 @@ ), ), }, - ignore=["re:.*lm_head"], + ignore=["re:visual.*", "re:.*lm_head"], dampening_frac=0.5, ) diff --git a/src/llmcompressor/modifiers/utils/pytorch_helpers.py b/src/llmcompressor/modifiers/utils/pytorch_helpers.py index b3ec829e8..00165c98a 100644 --- a/src/llmcompressor/modifiers/utils/pytorch_helpers.py +++ b/src/llmcompressor/modifiers/utils/pytorch_helpers.py @@ -1,5 +1,5 @@ from itertools import cycle -from typing import Callable, Dict, List, Optional, Tuple, Any +from typing import Any, Callable, Dict, List, Optional, Tuple import torch from torch.nn import Module diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index bb2e9a352..142698967 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -1,18 +1,18 @@ import torch import torch.utils.data.dataloader import tqdm +from compressed_tensors.utils import get_execution_device from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch from llmcompressor.pytorch.utils.helpers import tensors_to_device from llmcompressor.utils.helpers import calibration_forward_context -from compressed_tensors.utils import get_execution_device __all__ = ["run_pipeline"] def run_pipeline(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader): model_device = get_execution_device(model) - + with calibration_forward_context(model): for batch in tqdm.tqdm(dataloader, desc="Calibrating"): batch = apply_pad_mask_to_batch(batch) diff --git a/src/llmcompressor/transformers/utils/data_collator.py b/src/llmcompressor/transformers/utils/data_collator.py index 058ce8af1..454ada61e 100644 --- a/src/llmcompressor/transformers/utils/data_collator.py +++ b/src/llmcompressor/transformers/utils/data_collator.py @@ -2,6 +2,7 @@ __all__ = ["mllama_data_collator", "pixtral_data_collator"] + def mllama_data_collator(batch): assert len(batch) == 1 return { @@ -13,6 +14,7 @@ def mllama_data_collator(batch): "cross_attention_mask": torch.tensor(batch[0]["cross_attention_mask"]), } + def pixtral_data_collator(batch): assert len(batch) == 1 return { @@ -21,6 +23,7 @@ def pixtral_data_collator(batch): "pixel_values": torch.tensor(batch[0]["pixel_values"])[0], } + def qwen2_vl_data_collator(batch): assert len(batch) == 1 return { @@ -28,4 +31,4 @@ def qwen2_vl_data_collator(batch): "attention_mask": torch.tensor(batch[0]["attention_mask"]), "pixel_values": torch.tensor(batch[0]["pixel_values"]), "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]), - } \ No newline at end of file + } From dc14e95d5b64e7eb560fb4b5240bcf3671cd52f5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 19 Dec 2024 21:55:58 +0000 Subject: [PATCH 231/285] remove qwen while unsupported --- examples/multimodal_vision/qwen2_vl.py | 74 -------------------------- 1 file changed, 74 deletions(-) delete mode 100644 examples/multimodal_vision/qwen2_vl.py diff --git a/examples/multimodal_vision/qwen2_vl.py b/examples/multimodal_vision/qwen2_vl.py deleted file mode 100644 index 2b27e81b3..000000000 --- a/examples/multimodal_vision/qwen2_vl.py +++ /dev/null @@ -1,74 +0,0 @@ -import os - -from compressed_tensors.quantization import ( - QuantizationArgs, - QuantizationScheme, - QuantizationStrategy, - QuantizationType, -) -from transformers import AutoProcessor, Qwen2VLForConditionalGeneration - -from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.transformers import oneshot -from llmcompressor.transformers.utils.data_collator import qwen2_vl_data_collator - -# Load model. -model_id = "Qwen/Qwen2-VL-2B-Instruct" -model = Qwen2VLForConditionalGeneration.from_pretrained( - model_id, device_map="auto", torch_dtype="auto" -) -processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) - -# Oneshot arguments -DATASET_ID = "flickr30k" -DATASET_SPLIT = {"calibration": "test[:512]"} -NUM_CALIBRATION_SAMPLES = 512 -MAX_SEQUENCE_LENGTH = 2048 - -# Recipe -recipe = GPTQModifier( - targets="Linear", - config_groups={ - "config_group": QuantizationScheme( - targets=["Linear"], - weights=QuantizationArgs( - num_bits=4, - type=QuantizationType.INT, - strategy=QuantizationStrategy.GROUP, - group_size=128, - symmetric=True, - dynamic=False, - actorder="dynamic", - ), - ), - }, - ignore=["re:visual.*", "re:.*lm_head"], - dampening_frac=0.5, -) - -# Perform oneshot -save_name = model_id.split("/")[1] + "-W8A8" -save_path = os.path.join("./my_test/", save_name) -print("Starting quantization") -oneshot( - model=model, - tokenizer=model_id, - # dataset=ds, - dataset=DATASET_ID, - splits=DATASET_SPLIT, - recipe=recipe, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, - trust_remote_code_model=True, - output_dir=save_path, - data_collator=qwen2_vl_data_collator, -) - -processor.save_pretrained(save_path) - -# Confirm generations of the quantized model look sane. -print("========== SAMPLE GENERATION ==============") -input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") -output = model.generate(input_ids, max_new_tokens=20) -print(processor.decode(output[0])) -print("==========================================") From 47249c564f92fa5aa2c0de83486ba0a9331f543b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 19 Dec 2024 21:58:55 +0000 Subject: [PATCH 232/285] remove smoothquant while unsupported --- examples/multimodal_vision/mllama.py | 1 - examples/multimodal_vision/pixtral.py | 1 - 2 files changed, 2 deletions(-) diff --git a/examples/multimodal_vision/mllama.py b/examples/multimodal_vision/mllama.py index b9241f53a..e87e42c16 100644 --- a/examples/multimodal_vision/mllama.py +++ b/examples/multimodal_vision/mllama.py @@ -22,7 +22,6 @@ # Recipe recipe = [ - # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), GPTQModifier( targets="Linear", scheme="W8A8", diff --git a/examples/multimodal_vision/pixtral.py b/examples/multimodal_vision/pixtral.py index 4687e150b..6a37f7250 100644 --- a/examples/multimodal_vision/pixtral.py +++ b/examples/multimodal_vision/pixtral.py @@ -24,7 +24,6 @@ # Recipe recipe = [ - # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), GPTQModifier( targets="Linear", scheme="W8A8", From de40a841e619a053c4186c47894377aea34c305e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 19 Dec 2024 17:06:31 -0500 Subject: [PATCH 233/285] clean up examples --- examples/multimodal_vision/mllama.py | 11 +++++------ examples/multimodal_vision/pixtral.py | 13 +++++-------- .../llmcompressor/modifiers/calibration/__init__.py | 0 3 files changed, 10 insertions(+), 14 deletions(-) create mode 100644 tests/llmcompressor/modifiers/calibration/__init__.py diff --git a/examples/multimodal_vision/mllama.py b/examples/multimodal_vision/mllama.py index e87e42c16..20fe316be 100644 --- a/examples/multimodal_vision/mllama.py +++ b/examples/multimodal_vision/mllama.py @@ -1,5 +1,3 @@ -import os - from transformers import AutoProcessor from llmcompressor.modifiers.quantization import GPTQModifier @@ -30,9 +28,6 @@ ] # Perform oneshot -save_name = model_id.split("/")[1] + "-W8A8" -save_path = os.path.join("./my_test/", save_name) -print("Starting quantization") oneshot( model=model, tokenizer=model_id, @@ -42,7 +37,6 @@ max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, trust_remote_code_model=True, - output_dir=save_path, data_collator=mllama_data_collator, ) @@ -52,3 +46,8 @@ output = model.generate(input_ids, max_new_tokens=20) print(processor.decode(output[0])) print("==========================================") + +# Save to disk compressed. +SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128" +model.save_pretrained(SAVE_DIR, save_compressed=True) +processor.save_pretrained(SAVE_DIR) diff --git a/examples/multimodal_vision/pixtral.py b/examples/multimodal_vision/pixtral.py index 6a37f7250..d9a161d01 100644 --- a/examples/multimodal_vision/pixtral.py +++ b/examples/multimodal_vision/pixtral.py @@ -1,10 +1,6 @@ -import os - from transformers import AutoProcessor from llmcompressor.modifiers.quantization import GPTQModifier - -# from llmcompressor.pytorch.data_collator import DataCollator from llmcompressor.transformers import oneshot from llmcompressor.transformers.tracing import TracableLlavaForConditionalGeneration from llmcompressor.transformers.utils.data_collator import pixtral_data_collator @@ -33,9 +29,6 @@ ] # Perform oneshot -save_name = model_id.split("/")[1] + "-W8A8" -save_path = os.path.join("./my_test/", save_name) -print("Starting quantization") oneshot( model=model, tokenizer=model_id, @@ -45,7 +38,6 @@ max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, trust_remote_code_model=True, - output_dir=save_path, data_collator=pixtral_data_collator, ) @@ -55,3 +47,8 @@ output = model.generate(input_ids, max_new_tokens=20) print(processor.decode(output[0])) print("==========================================") + +# Save to disk compressed. +SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128" +model.save_pretrained(SAVE_DIR, save_compressed=True) +processor.save_pretrained(SAVE_DIR) diff --git a/tests/llmcompressor/modifiers/calibration/__init__.py b/tests/llmcompressor/modifiers/calibration/__init__.py new file mode 100644 index 000000000..e69de29bb From 7f6e8cdcd767db0b58d93e68e0f52590a20143f8 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 20 Dec 2024 18:12:29 +0000 Subject: [PATCH 234/285] handle non-fast tokenizers --- .../transformers/finetune/text_generation.py | 29 ++++++++++++++----- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index 2c596f557..5a06b302f 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -250,14 +250,27 @@ def initialize_processor_from_path( ) -> Processor: processor_src = model_args.processor processor_src = processor_src or get_shared_processor_src(model, teacher) - processor = AutoProcessor.from_pretrained( - processor_src, - cache_dir=model_args.cache_dir, - use_fast=True, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - trust_remote_code=model_args.trust_remote_code_model, - ) + # The use_fast=True option is not currently supported safely in Transformers + # See: https://github.com/huggingface/transformers/pull/34836#issuecomment-2491809727 # noqa: E501 + try: + processor = AutoProcessor.from_pretrained( + processor_src, + cache_dir=model_args.cache_dir, + use_fast=True, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + trust_remote_code=model_args.trust_remote_code_model, + ) + except Exception: + logger.debug("Could not load fast processor, loading slow processor instead") + processor = AutoProcessor.from_pretrained( + processor_src, + cache_dir=model_args.cache_dir, + use_fast=False, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + trust_remote_code=model_args.trust_remote_code_model, + ) return processor From 1c8afe436026bf261e4979444aec32dd40ac1717 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 20 Dec 2024 18:12:29 +0000 Subject: [PATCH 235/285] handle non-fast tokenizers --- .../transformers/finetune/text_generation.py | 29 ++++++++++++++----- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index 2c596f557..5a06b302f 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -250,14 +250,27 @@ def initialize_processor_from_path( ) -> Processor: processor_src = model_args.processor processor_src = processor_src or get_shared_processor_src(model, teacher) - processor = AutoProcessor.from_pretrained( - processor_src, - cache_dir=model_args.cache_dir, - use_fast=True, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - trust_remote_code=model_args.trust_remote_code_model, - ) + # The use_fast=True option is not currently supported safely in Transformers + # See: https://github.com/huggingface/transformers/pull/34836#issuecomment-2491809727 # noqa: E501 + try: + processor = AutoProcessor.from_pretrained( + processor_src, + cache_dir=model_args.cache_dir, + use_fast=True, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + trust_remote_code=model_args.trust_remote_code_model, + ) + except Exception: + logger.debug("Could not load fast processor, loading slow processor instead") + processor = AutoProcessor.from_pretrained( + processor_src, + cache_dir=model_args.cache_dir, + use_fast=False, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + trust_remote_code=model_args.trust_remote_code_model, + ) return processor From 3a9816cfca58735b91227bb42916b2ea352f0a51 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 20 Dec 2024 21:12:04 +0000 Subject: [PATCH 236/285] address nits, add logging Signed-off-by: Kyle Sayers --- .../transformers/finetune/data/base.py | 30 ++++++++++++++----- .../transformers/finetune/data/data_args.py | 4 +-- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index c7eb33259..08ea83512 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -88,6 +88,7 @@ def __call__(self, add_labels: bool = True) -> DatasetType: if isinstance(dataset, str): # load dataset: load from huggingface or disk dataset = self.load_dataset() + logger.debug(f"Raw dataset: {get_columns(dataset)}") if self.preprocess is not None: # preprocess: apply template or preprocessing function @@ -98,13 +99,16 @@ def __call__(self, add_labels: bool = True) -> DatasetType: num_proc=self.data_args.preprocessing_num_workers, desc="Preprocessing", ) + logger.debug(f"Dataset after preprocessing: {get_columns(dataset)}") # rename and remove columns match processor kwargs dataset = self.rename_columns(dataset) + logger.debug(f"Dataset after column renaming: {get_columns(dataset)}") - if "input_ids" not in get_column_names(dataset): + if "input_ids" not in get_columns(dataset): # tokenize/ process dataset = self.filter_tokenizer_args(dataset) + logger.debug(f"Tokenizer args after filtering: {get_columns(dataset)}") dataset = self.map( dataset, self.tokenize, @@ -112,12 +116,13 @@ def __call__(self, add_labels: bool = True) -> DatasetType: keep_in_memory=True, # bug occurs when not batched and not in memory, # subsequent ds.map calls are always batched, # regardless of `batched` argument - remove_columns=get_column_names(dataset), # assumes that input names + remove_columns=get_columns(dataset), # assumes that input names # and output names are disjoint num_proc=self.data_args.preprocessing_num_workers, load_from_cache_file=not self.data_args.overwrite_cache, desc="Tokenizing", ) + logger.debug(f"Model kwargs after tokenizing: {get_columns(dataset)}") if self.data_args.concatenate_data: # postprocess: group text @@ -129,6 +134,7 @@ def __call__(self, add_labels: bool = True) -> DatasetType: load_from_cache_file=not self.data_args.overwrite_cache, desc="Concatenating data", ) + logger.debug(f"Model kwargs after concatenating: {get_columns(dataset)}") if add_labels: # postprocess: add labels @@ -140,10 +146,13 @@ def __call__(self, add_labels: bool = True) -> DatasetType: load_from_cache_file=not self.data_args.overwrite_cache, desc="Adding labels", ) + logger.debug(f"Model kwargs after adding labels: {get_columns(dataset)}") - elif self.PROMPT_KEY in get_column_names(dataset): + elif self.PROMPT_KEY in get_columns(dataset): dataset.remove_columns(self.PROMPT_KEY) + logger.debug("Removed prompt key") + logger.debug(f"Model kwargs after postprocessing: {get_columns(dataset)}") return dataset def load_dataset(self): @@ -167,6 +176,7 @@ def load_dataset(self): else self.data_args.dataset_name, ) + logger.debug(f"Loading dataset {self.data_args.dataset}") return get_raw_dataset( self.data_args, None, @@ -178,8 +188,8 @@ def load_dataset(self): @cached_property def preprocess(self) -> Union[Callable[[LazyRow], Any], None]: """ - The function must return keys which correspond to tokenizer kwargs, optionally - including PROMPT_KEY + The function must return keys which correspond to processor/tokenizer kwargs, + optionally including PROMPT_KEY """ preprocessing_func = self.data_args.preprocessing_func @@ -204,8 +214,9 @@ def dataset_template(self) -> Union[Callable[[Any], Any], None]: def rename_columns(self, dataset: DatasetType) -> DatasetType: # rename columns to match processor/tokenizer kwargs - column_names = get_column_names(dataset) + column_names = get_columns(dataset) if self.data_args.text_column in column_names and "text" not in column_names: + logger.debug(f"Renaming column `{self.data_args.text_column}` to `text`") dataset = dataset.rename_column(self.data_args.text_column, "text") return dataset @@ -218,8 +229,11 @@ def filter_tokenizer_args(self, dataset: DatasetType) -> DatasetType: for key, param in signature.parameters.items() if param.kind not in (Kind.VAR_POSITIONAL, Kind.VAR_KEYWORD) ) + logger.debug( + f"Found processor args `{tokenizer_args}`. Removing all other columns" + ) - column_names = get_column_names(dataset) + column_names = get_columns(dataset) return dataset.remove_columns(list(set(column_names) - set(tokenizer_args))) def tokenize(self, data: LazyRow) -> Dict[str, Any]: @@ -304,7 +318,7 @@ def map( return dataset -def get_column_names(dataset: DatasetType) -> List[str]: +def get_columns(dataset: DatasetType) -> List[str]: column_names = dataset.column_names if isinstance(column_names, dict): column_names = sum(column_names.values(), []) diff --git a/src/llmcompressor/transformers/finetune/data/data_args.py b/src/llmcompressor/transformers/finetune/data/data_args.py index d0c097978..0c108cd91 100644 --- a/src/llmcompressor/transformers/finetune/data/data_args.py +++ b/src/llmcompressor/transformers/finetune/data/data_args.py @@ -45,7 +45,7 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments): remove_columns: Union[None, str, List] = field( default=None, - metadata={"help": "Column names to remove after preprocessing (depreciated)"}, + metadata={"help": "Column names to remove after preprocessing (deprecated)"}, ) preprocessing_func: Union[None, str, Callable] = field( @@ -53,7 +53,7 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments): metadata={ "help": ( "Typically a function which applies a chat template. Can take the form " - "of iither a function to apply to the dataset, a name defined in " + "of either a function to apply to the dataset, a name defined in " "src/llmcompressor/transformers/utils/preprocessing_functions.py, or " "a path to a function definition of the form /path/to/file.py:func" ) From 7be0c88ca20ffcb05dd6ceb0364370d090606329 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 20 Dec 2024 21:14:22 +0000 Subject: [PATCH 237/285] add back copyrights --- .../transformers/finetune/data/cnn_dailymail.py | 13 +++++++++++++ .../transformers/finetune/data/custom.py | 13 +++++++++++++ .../transformers/finetune/data/evolcodealpaca.py | 13 +++++++++++++ .../transformers/finetune/data/open_platypus.py | 13 +++++++++++++ .../transformers/finetune/data/ultrachat_200k.py | 13 +++++++++++++ 5 files changed, 65 insertions(+) diff --git a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py index 06ad3ecfa..473c82fd1 100644 --- a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py +++ b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py @@ -1,3 +1,16 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from copy import deepcopy from typing import TYPE_CHECKING diff --git a/src/llmcompressor/transformers/finetune/data/custom.py b/src/llmcompressor/transformers/finetune/data/custom.py index 7cff3c1d9..0f361fc95 100644 --- a/src/llmcompressor/transformers/finetune/data/custom.py +++ b/src/llmcompressor/transformers/finetune/data/custom.py @@ -1,3 +1,16 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from llmcompressor.transformers.finetune.data import TextGenerationDataset diff --git a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py index 932bfa54c..8421abcff 100644 --- a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py +++ b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py @@ -1,3 +1,16 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from copy import deepcopy from typing import TYPE_CHECKING diff --git a/src/llmcompressor/transformers/finetune/data/open_platypus.py b/src/llmcompressor/transformers/finetune/data/open_platypus.py index 3b25986ca..c113fb8d3 100644 --- a/src/llmcompressor/transformers/finetune/data/open_platypus.py +++ b/src/llmcompressor/transformers/finetune/data/open_platypus.py @@ -1,3 +1,16 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from copy import deepcopy from typing import TYPE_CHECKING diff --git a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py index 47af48cdd..b2f8bf4ce 100644 --- a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py +++ b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py @@ -1,3 +1,16 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from copy import deepcopy from typing import TYPE_CHECKING From bedbf8cbfc786dca221a73a4bb921ec20d275a98 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 20 Dec 2024 21:20:49 +0000 Subject: [PATCH 238/285] correctly update helptext Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/finetune/data/data_args.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/data_args.py b/src/llmcompressor/transformers/finetune/data/data_args.py index 0c108cd91..7d0bc14ce 100644 --- a/src/llmcompressor/transformers/finetune/data/data_args.py +++ b/src/llmcompressor/transformers/finetune/data/data_args.py @@ -62,10 +62,7 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments): data_collator: Callable[[Any], Any] = field( default_factory=lambda: DefaultDataCollator(), - metadata={ - "help": "For custom datasets only. The function to used to form a batch " - "from the dataset" - }, + metadata={"help": "The function to used to form a batch from the dataset"}, ) From 42f78925ff81f1da774be31a2679ce179a31d0e1 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 20 Dec 2024 21:48:31 +0000 Subject: [PATCH 239/285] do not remove prompt key Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/finetune/data/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index 08ea83512..9736ca0e6 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -234,7 +234,9 @@ def filter_tokenizer_args(self, dataset: DatasetType) -> DatasetType: ) column_names = get_columns(dataset) - return dataset.remove_columns(list(set(column_names) - set(tokenizer_args))) + return dataset.remove_columns( + list(set(column_names) - set(tokenizer_args) - set([self.PROMPT_KEY])) + ) def tokenize(self, data: LazyRow) -> Dict[str, Any]: # separate prompt From 41396282ec7e4dca1b42987e15ee6e371ddf7222 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 20 Dec 2024 21:53:42 +0000 Subject: [PATCH 240/285] add no copyright to hf files Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/tracing/llava.py | 1 + src/llmcompressor/transformers/tracing/mistral.py | 1 + src/llmcompressor/transformers/tracing/mllama.py | 1 + 3 files changed, 3 insertions(+) diff --git a/src/llmcompressor/transformers/tracing/llava.py b/src/llmcompressor/transformers/tracing/llava.py index b918f242a..2a80d2efb 100644 --- a/src/llmcompressor/transformers/tracing/llava.py +++ b/src/llmcompressor/transformers/tracing/llava.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# vllm-project: no copyright """PyTorch Llava model.""" from functools import wraps diff --git a/src/llmcompressor/transformers/tracing/mistral.py b/src/llmcompressor/transformers/tracing/mistral.py index 08f8d32d7..bbfa9d319 100644 --- a/src/llmcompressor/transformers/tracing/mistral.py +++ b/src/llmcompressor/transformers/tracing/mistral.py @@ -18,6 +18,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# vllm-project: no copyright """PyTorch Mistral model.""" import math diff --git a/src/llmcompressor/transformers/tracing/mllama.py b/src/llmcompressor/transformers/tracing/mllama.py index 301517618..955f7c270 100644 --- a/src/llmcompressor/transformers/tracing/mllama.py +++ b/src/llmcompressor/transformers/tracing/mllama.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# vllm-project: no copyright """PyTorch Mllama model.""" import math From 15fa27d1b9d3ddd10339a5ed9fb329517850ee79 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 23 Dec 2024 14:35:24 +0000 Subject: [PATCH 241/285] remove prompt key Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/finetune/data/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index 9736ca0e6..ddc348b89 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -149,7 +149,7 @@ def __call__(self, add_labels: bool = True) -> DatasetType: logger.debug(f"Model kwargs after adding labels: {get_columns(dataset)}") elif self.PROMPT_KEY in get_columns(dataset): - dataset.remove_columns(self.PROMPT_KEY) + dataset = dataset.remove_columns(self.PROMPT_KEY) logger.debug("Removed prompt key") logger.debug(f"Model kwargs after postprocessing: {get_columns(dataset)}") From ae16da322b58b667a0af40759a5042c61fa538eb Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 23 Dec 2024 14:59:33 +0000 Subject: [PATCH 242/285] do not process tokenized datasets, including adding labels Signed-off-by: Kyle Sayers --- .../transformers/finetune/runner.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py index 0841305d8..0a07c45eb 100644 --- a/src/llmcompressor/transformers/finetune/runner.py +++ b/src/llmcompressor/transformers/finetune/runner.py @@ -100,13 +100,19 @@ def _get_split_name(inp_str): else "custom" ) for split_name, split_str in splits.items(): - dataset_manager = TextGenerationDataset.load_from_registry( - registry_id, - data_args=self._data_args, - split=split_str, - processor=processor, - ) - tokenized_datasets[split_name] = dataset_manager(add_labels=add_labels) + dataset = self._data_args.dataset + if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names: + # dataset is already tokenized + tokenized_datasets[split_name] = dataset + else: + # dataset needs to be tokenized + dataset_manager = TextGenerationDataset.load_from_registry( + registry_id, + data_args=self._data_args, + split=split_str, + processor=processor, + ) + tokenized_datasets[split_name] = dataset_manager(add_labels=add_labels) self.datasets = make_dataset_splits( tokenized_datasets, From c3a663abc6237a11d6afd82b35910ddf41b4e384 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 23 Dec 2024 17:47:59 +0000 Subject: [PATCH 243/285] rename classes so the saved config is the original class --- src/llmcompressor/transformers/tracing/__init__.py | 10 +++++++--- src/llmcompressor/transformers/tracing/llava.py | 4 ++-- src/llmcompressor/transformers/tracing/mistral.py | 2 +- src/llmcompressor/transformers/tracing/mllama.py | 2 +- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/llmcompressor/transformers/tracing/__init__.py b/src/llmcompressor/transformers/tracing/__init__.py index 88be0fe88..c3c14a2d5 100644 --- a/src/llmcompressor/transformers/tracing/__init__.py +++ b/src/llmcompressor/transformers/tracing/__init__.py @@ -1,6 +1,10 @@ -from .llava import TracableLlavaForConditionalGeneration -from .mistral import TracableMistralForCausalLM -from .mllama import TracableMllamaForConditionalGeneration +from .llava import ( + LlavaForConditionalGeneration as TracableLlavaForConditionalGeneration, +) +from .mistral import MistralForCausalLM as TracableMistralForCausalLM +from .mllama import ( + MllamaForConditionalGeneration as TracableMllamaForConditionalGeneration, +) __all__ = [ "TracableLlavaForConditionalGeneration", diff --git a/src/llmcompressor/transformers/tracing/llava.py b/src/llmcompressor/transformers/tracing/llava.py index 2a80d2efb..7a71f2564 100644 --- a/src/llmcompressor/transformers/tracing/llava.py +++ b/src/llmcompressor/transformers/tracing/llava.py @@ -29,10 +29,10 @@ ) from transformers.models.mistral.configuration_mistral import MistralConfig -from .mistral import TracableMistralForCausalLM +from .mistral import MistralForCausalLM as TracableMistralForCausalLM -class TracableLlavaForConditionalGeneration(LlavaForConditionalGeneration): +class LlavaForConditionalGeneration(LlavaForConditionalGeneration): def __init__(self, config: LlavaConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) diff --git a/src/llmcompressor/transformers/tracing/mistral.py b/src/llmcompressor/transformers/tracing/mistral.py index bbfa9d319..7a63099c3 100644 --- a/src/llmcompressor/transformers/tracing/mistral.py +++ b/src/llmcompressor/transformers/tracing/mistral.py @@ -1090,7 +1090,7 @@ def _update_causal_mask( return causal_mask -class TracableMistralForCausalLM(MistralPreTrainedModel, GenerationMixin): +class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} diff --git a/src/llmcompressor/transformers/tracing/mllama.py b/src/llmcompressor/transformers/tracing/mllama.py index 955f7c270..512ba4227 100644 --- a/src/llmcompressor/transformers/tracing/mllama.py +++ b/src/llmcompressor/transformers/tracing/mllama.py @@ -2279,7 +2279,7 @@ def forward( """The Mllama model which consists of a vision encoder and a language model.""", MLLAMA_START_DOCSTRING, ) -class TracableMllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): +class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): _supports_quantized_cache = ( False # quant cache not supported in encoder-decoder setting ) From e71f4e582cdbe81e3cc3f70847058a30fb1bbaf1 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 23 Dec 2024 15:55:23 -0500 Subject: [PATCH 244/285] remove default chat template Signed-off-by: Kyle Sayers --- .../transformers/finetune/data/flickr_30k.py | 21 ------------------- .../finetune/data/ultrachat_200k.py | 21 ------------------- 2 files changed, 42 deletions(-) diff --git a/src/llmcompressor/transformers/finetune/data/flickr_30k.py b/src/llmcompressor/transformers/finetune/data/flickr_30k.py index 2c55bf42d..1df46921e 100644 --- a/src/llmcompressor/transformers/finetune/data/flickr_30k.py +++ b/src/llmcompressor/transformers/finetune/data/flickr_30k.py @@ -16,33 +16,12 @@ class Flickr30K(TextGenerationDataset): :param processor: processor or tokenizer to use on dataset """ - DEFAULT_CHAT_TEMPLATE = ( - "{% for message in messages %}\n" - "{% if message['role'] == 'user' %}\n" - "{{ '<|user|>\n' + message['content'] + eos_token }}\n" - "{% elif message['role'] == 'system' %}\n" - "{{ '<|system|>\n' + message['content'] + eos_token }}\n" - "{% elif message['role'] == 'assistant' %}\n" - "{{ '<|assistant|>\n' + message['content'] + eos_token }}\n" - "{% endif %}\n" - "{% if loop.last and add_generation_prompt %}\n" - "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" - ) - def __init__(self, data_args: "DataArgs", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "lmms-lab/flickr30k" super().__init__(data_args=data_args, split=split, processor=processor) - if ( - self.tokenizer is not None - and getattr(self.tokenizer, "chat_template", None) is None - ): - # note that since tokenizer is a member of processor, - # this change affects processor.apply_chat_template - self.tokenizer.chat_template = self.DEFAULT_CHAT_TEMPLATE - def dataset_template(self, sample): messages = [ { diff --git a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py index b2f8bf4ce..f541dee00 100644 --- a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py +++ b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py @@ -31,19 +31,6 @@ class UltraChatDataset(TextGenerationDataset): :param processor: processor or tokenizer to use on dataset """ - DEFAULT_CHAT_TEMPLATE = ( - "{% for message in messages %}\n" - "{% if message['role'] == 'user' %}\n" - "{{ '<|user|>\n' + message['content'] + eos_token }}\n" - "{% elif message['role'] == 'system' %}\n" - "{{ '<|system|>\n' + message['content'] + eos_token }}\n" - "{% elif message['role'] == 'assistant' %}\n" - "{{ '<|assistant|>\n' + message['content'] + eos_token }}\n" - "{% endif %}\n" - "{% if loop.last and add_generation_prompt %}\n" - "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" - ) - def __init__(self, data_args: "DataArgs", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "HuggingFaceH4/ultrachat_200k" @@ -54,14 +41,6 @@ def __init__(self, data_args: "DataArgs", split: str, processor: Processor): super().__init__(data_args=data_args, split=split, processor=processor) - if ( - self.tokenizer is not None - and getattr(self.tokenizer, "chat_template", None) is None - ): - # note that since tokenizer is a member of processor, - # this change affects processor.apply_chat_template - self.tokenizer.chat_template = self.DEFAULT_CHAT_TEMPLATE - def dataset_template(self, sample): messages = sample["messages"] if messages[0]["role"] != "system": From 0195fab07a7d204d82c3f73a2c9c7e27740ef544 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 23 Dec 2024 17:32:54 -0500 Subject: [PATCH 245/285] support llava-1.5 via installing metadata Signed-off-by: Kyle Sayers --- examples/multimodal_vision/llava.py | 54 +++++++++++++++++++ .../transformers/tracing/llava.py | 49 +++++++++++++++-- .../transformers/utils/data_collator.py | 9 ++++ 3 files changed, 109 insertions(+), 3 deletions(-) create mode 100644 examples/multimodal_vision/llava.py diff --git a/examples/multimodal_vision/llava.py b/examples/multimodal_vision/llava.py new file mode 100644 index 000000000..68653f182 --- /dev/null +++ b/examples/multimodal_vision/llava.py @@ -0,0 +1,54 @@ +from transformers import AutoProcessor + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import oneshot +from llmcompressor.transformers.tracing import TracableLlavaForConditionalGeneration +from llmcompressor.transformers.utils.data_collator import llava_data_collator + +# Load model. +model_id = "llava-hf/llava-1.5-7b-hf" +model = TracableLlavaForConditionalGeneration.from_pretrained( + model_id, device_map="auto", torch_dtype="auto" +) +processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + +# Oneshot arguments +DATASET_ID = "flickr30k" +DATASET_SPLIT = {"calibration": "test[:512]"} +NUM_CALIBRATION_SAMPLES = 512 +MAX_SEQUENCE_LENGTH = 2048 + +# Recipe +recipe = [ + GPTQModifier( + targets="Linear", + scheme="W8A8", + ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"], + sequential_targets=["LlamaDecoderLayer"], + ), +] + +# Perform oneshot +oneshot( + model=model, + tokenizer=model_id, + dataset=DATASET_ID, + splits=DATASET_SPLIT, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, + data_collator=llava_data_collator, +) + +# Confirm generations of the quantized model look sane. +print("========== SAMPLE GENERATION ==============") +input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=20) +print(processor.decode(output[0])) +print("==========================================") + +# Save to disk compressed. +SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128" +model.save_pretrained(SAVE_DIR, save_compressed=True) +processor.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/transformers/tracing/llava.py b/src/llmcompressor/transformers/tracing/llava.py index 7a71f2564..e77fda2df 100644 --- a/src/llmcompressor/transformers/tracing/llava.py +++ b/src/llmcompressor/transformers/tracing/llava.py @@ -16,7 +16,6 @@ # vllm-project: no copyright """PyTorch Llava model.""" -from functools import wraps from typing import List, Optional, Tuple, Union import torch @@ -28,10 +27,46 @@ logger, ) from transformers.models.mistral.configuration_mistral import MistralConfig +from transformers.utils.fx import HFProxy from .mistral import MistralForCausalLM as TracableMistralForCausalLM +def maybe_install_metadata_image_features( + image_features: Union[torch.Tensor, HFProxy], + pixel_values: Union[torch.Tensor, HFProxy], + config: LlavaConfig, +): + if isinstance(image_features, HFProxy): + # (num_images, image_length, embed_dim) + num_images = pixel_values._metadata.size(0) + image_length = config.image_seq_length + embed_dim = config.vision_config.intermediate_size + + original_fn = image_features.tracer.patched_torch_methods["empty"][1] + metadata = original_fn( + (num_images, image_length, embed_dim), device=torch.device("meta") + ) + image_features.install_metadata(metadata) + + return image_features + + +def maybe_install_metadata_inputs_embeds( + inputs_embeds_masked: Union[torch.Tensor, HFProxy], + inputs_embeds: Union[torch.Tensor, HFProxy], + special_image_mask: Union[torch.Tensor, HFProxy], + image_features: Union[torch.Tensor, HFProxy], +): + if isinstance(inputs_embeds_masked, HFProxy): + metadata = inputs_embeds._metadata.masked_scatter( + special_image_mask._metadata.to(bool), image_features._metadata + ) + inputs_embeds_masked.install_metadata(metadata) + + return inputs_embeds + + class LlavaForConditionalGeneration(LlavaForConditionalGeneration): def __init__(self, config: LlavaConfig): super().__init__(config) @@ -53,7 +88,6 @@ def __init__(self, config: LlavaConfig): self.__class__.__name__ = "LlavaForConditionalGeneration" - @wraps(LlavaForConditionalGeneration.forward) def forward( self, input_ids: torch.LongTensor = None, @@ -127,6 +161,10 @@ def forward( vision_feature_select_strategy=vision_feature_select_strategy, ) + image_features = maybe_install_metadata_image_features( + image_features, pixel_values, self.config + ) + if legacy_processing: logger.warning_once( "Expanding inputs for image tokens in LLaVa should be done in processing. " @@ -202,10 +240,15 @@ def forward( image_features = image_features.to( inputs_embeds.device, inputs_embeds.dtype ) - inputs_embeds = inputs_embeds.masked_scatter( + inputs_embeds_masked = inputs_embeds.masked_scatter( special_image_mask, image_features ) + inputs_embeds_masked = maybe_install_metadata_inputs_embeds( + inputs_embeds_masked, inputs_embeds, special_image_mask, image_features + ) + inputs_embeds = inputs_embeds_masked + outputs = self.language_model( attention_mask=attention_mask, position_ids=position_ids, diff --git a/src/llmcompressor/transformers/utils/data_collator.py b/src/llmcompressor/transformers/utils/data_collator.py index 454ada61e..360fe2f82 100644 --- a/src/llmcompressor/transformers/utils/data_collator.py +++ b/src/llmcompressor/transformers/utils/data_collator.py @@ -24,6 +24,15 @@ def pixtral_data_collator(batch): } +def llava_data_collator(batch): + assert len(batch) == 1 + return { + "input_ids": torch.LongTensor(batch[0]["input_ids"]), + "attention_mask": torch.tensor(batch[0]["attention_mask"]), + "pixel_values": torch.tensor(batch[0]["pixel_values"]), + } + + def qwen2_vl_data_collator(batch): assert len(batch) == 1 return { From 148e61741948345261afb6f0ca08fe41579f7066 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 27 Dec 2024 03:35:04 +0000 Subject: [PATCH 246/285] account for models which improperly do not override the abstract methods Signed-off-by: Kyle Sayers --- .../transformers/sparsification/compressed_tensors_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index eba5c5882..4c1e798b2 100644 --- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -236,6 +236,10 @@ def patch_tied_tensors_bug(model: torch.nn.Module): input_embed = model.get_input_embeddings() output_embed = model.get_output_embeddings() + if input_embed is None or output_embed is None: + # some models fail to properly override the abstract methods + return + if storage_ptr(input_embed.weight) == storage_ptr(output_embed.weight): for module in (input_embed, output_embed): if not is_module_offloaded(module): From e5dd5827d8d36e82cbf9372c8b4a5f40bc64554d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 27 Dec 2024 04:22:35 +0000 Subject: [PATCH 247/285] add ChatGLMForConditionalGeneration Signed-off-by: Kyle Sayers --- .../transformers/tracing/__init__.py | 2 + .../tracing/glm/configuration_chatglm.py | 66 + .../tracing/glm/modeling_chatglm.py | 1332 +++++++++++++++++ .../transformers/tracing/glm/visual.py | 180 +++ .../transformers/utils/data_collator.py | 10 + 5 files changed, 1590 insertions(+) create mode 100644 src/llmcompressor/transformers/tracing/glm/configuration_chatglm.py create mode 100644 src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py create mode 100644 src/llmcompressor/transformers/tracing/glm/visual.py diff --git a/src/llmcompressor/transformers/tracing/__init__.py b/src/llmcompressor/transformers/tracing/__init__.py index c3c14a2d5..3bb896d1f 100644 --- a/src/llmcompressor/transformers/tracing/__init__.py +++ b/src/llmcompressor/transformers/tracing/__init__.py @@ -5,9 +5,11 @@ from .mllama import ( MllamaForConditionalGeneration as TracableMllamaForConditionalGeneration, ) +from .glm.modeling_chatglm import ChatGLMForConditionalGeneration __all__ = [ "TracableLlavaForConditionalGeneration", "TracableMllamaForConditionalGeneration", "TracableMistralForCausalLM", + "ChatGLMForConditionalGeneration", ] diff --git a/src/llmcompressor/transformers/tracing/glm/configuration_chatglm.py b/src/llmcompressor/transformers/tracing/glm/configuration_chatglm.py new file mode 100644 index 000000000..de54e92e2 --- /dev/null +++ b/src/llmcompressor/transformers/tracing/glm/configuration_chatglm.py @@ -0,0 +1,66 @@ +from transformers import PretrainedConfig + + +class ChatGLMConfig(PretrainedConfig): + model_type = "chatglm" + + def __init__( + self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + classifier_dropout=None, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + rope_ratio=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + pre_seq_len=None, + prefix_projection=False, + boi_token_id=None, + eoi_token_id=None, + **kwargs + ): + self.num_layers = num_layers + self.vocab_size = padded_vocab_size + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + self.seq_length = seq_length + self.hidden_dropout = hidden_dropout + self.classifier_dropout = classifier_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + self.rmsnorm = rmsnorm + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.post_layer_norm = post_layer_norm + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.bias_dropout_fusion = bias_dropout_fusion + self.multi_query_attention = multi_query_attention + self.multi_query_group_num = multi_query_group_num + self.rope_ratio = rope_ratio + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.fp32_residual_connection = fp32_residual_connection + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + self.boi_token_id = boi_token_id + self.eoi_token_id = eoi_token_id + super().__init__(**kwargs) \ No newline at end of file diff --git a/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py b/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py new file mode 100644 index 000000000..462ae200a --- /dev/null +++ b/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py @@ -0,0 +1,1332 @@ +""" PyTorch GLM-4V model. """ +import math +import sys +import torch +import torch.utils.checkpoint +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss +from torch.nn.utils import skip_init +from typing import Optional, Tuple, Union, List, Dict, Any + +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging, is_torch_npu_available +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput + +from .visual import EVA2CLIPModel +from .configuration_chatglm import ChatGLMConfig + +try: + from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available + + if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +except: + pass + +from torch.fx import wrap + +# flags required to enable jit fusion kernels + +if sys.platform != 'darwin' and not is_torch_npu_available(): + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + +logger = logging.get_logger(__name__) + +LANGUAGE_TOKEN_TYPE = 0 +VISION_TOKEN_TYPE = 1 + +_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM" +_CONFIG_FOR_DOC = "ChatGLMConfig" + + +def default_init(cls, *args, **kwargs): + return cls(*args, **kwargs) + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 198] = 5e4 + return scores + + +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix + Input shape: (batch-size, prefix-length) + Output shape: (batch-size, prefix-length, 2*layers*hidden) + """ + + def __init__(self, config: ChatGLMConfig): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2 + self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(kv_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, kv_size) + ) + else: + self.embedding = torch.nn.Embedding(config.pre_seq_len, + config.num_layers * config.kv_channels * config.multi_query_group_num * 2) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> List[torch.Tensor]: + """Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, rope_ratio=1, original_impl=False, device=None, dtype=None): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.dim = dim + self.original_impl = original_impl + self.rope_ratio = rope_ratio + + def impl(self, seq_length: int, dim: int, device: torch.device, dtype: torch.dtype): + base = 10000 * self.rope_ratio + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)) + seq = torch.arange(seq_length, device=inv_freq.device, dtype=torch.float32) + freqs = torch.outer(seq, inv_freq) + # first part even vector components, second part odd vector components, + # 2 * dim in dimension size + emb = torch.cat((freqs, freqs), dim=-1) + return emb + + def forward_impl( + self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000 + ): + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + base = base * self.rope_ratio + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, dtype=torch.float, device=device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).float() + + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + + # this is to mimic the behaviour of complex32, else we will get different results + if dtype in (torch.float16, torch.bfloat16, torch.int8): + cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() + return cache + + def forward(self, max_seq_len, offset=0): + if self.original_impl: + return self.forward_impl( + max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device + ) + else: + return self.impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device) + + +@torch.jit.script +def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: + # x: [b, np, sq, hn] + b, np, sq, hn = x.size(0), x.size(1), x.size(2), x.size(3) + rot_dim = rope_cache.shape[-2] * 2 + x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + # truncate to support variable sizes + rope_cache = rope_cache[:, :sq] + xshaped = x.reshape(b, np, sq, rot_dim // 2, 2) + rope_cache = rope_cache.view(-1, 1, sq, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + x_out2 = x_out2.flatten(3) + return torch.cat((x_out2, x_pass), dim=-1) + + +class RMSNorm(torch.nn.Module): + def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) + self.eps = eps + + def forward(self, hidden_states: torch.Tensor): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + return (self.weight * hidden_states).to(input_dtype) + + + +class CoreAttention(torch.nn.Module): + def __init__(self, config: ChatGLMConfig, layer_number): + super(CoreAttention, self).__init__() + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_partition = projection_size + self.hidden_size_per_attention_head = projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + self.coeff = coeff + + self.attention_dropout = torch.nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, value_layer, attention_mask): + pytorch_major_version = int(torch.__version__.split('.')[0]) + if pytorch_major_version >= 2: + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + is_causal=True) + else: + if attention_mask is not None: + attention_mask = ~attention_mask + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + attention_mask) + context_layer = context_layer.transpose(1, 2).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + else: + # Raw attention scores + + # [b, np, sq, sk] + output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2)) + + # [b, np, sq, hn] -> [b * np, sq, hn] + query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1) + # [b, np, sk, hn] -> [b * np, sk, hn] + key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = torch.empty( + output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, + device=query_layer.device + ) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer, # [b * np, sq, hn] + key_layer.transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + if self.attention_softmax_in_fp32: + attention_scores = attention_scores.float() + if self.coeff is not None: + attention_scores = attention_scores * self.coeff + if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: + attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3], + device=attention_scores.device, dtype=torch.bool) + attention_mask.tril_() + attention_mask = ~attention_mask + if attention_mask is not None: + attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = attention_probs.type_as(value_layer) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(attention_probs) + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) + # change view [b * np, sk, hn] + value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1) + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + # [b, np, sq, hn] --> [b, sq, np, hn] + context_layer = context_layer.transpose(1, 2).contiguous() + # [b, sq, np, hn] --> [b, sq, hp] + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + + return context_layer + +class SdpaAttention(CoreAttention): + def forward(self, query_layer, key_layer, value_layer, attention_mask): + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + is_causal=True, + dropout_p=self.config.attention_dropout if self.training else 0.0) + else: + if attention_mask is not None: + attention_mask = ~attention_mask + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + attention_mask, + dropout_p=self.config.attention_dropout if self.training else 0.0) + context_layer = context_layer.transpose(1, 2).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + return context_layer + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 +class FlashAttention2(CoreAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward(self, query_states, key_states, value_states, attention_mask): + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + batch_size, query_length = query_states.shape[:2] + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + dropout = self.config.attention_dropout if self.training else 0.0 + # Contains at least one padding token in the sequence + if attention_mask is not None: + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=None, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=None, causal=causal + ) + attn_output = attn_output.reshape(batch_size, query_length, self.hidden_size_per_partition).contiguous() + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim), + indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +CORE_ATTENTION_CLASSES = { + "eager": CoreAttention, + "sdpa": SdpaAttention, + "flash_attention_2": FlashAttention2 +} + +class SelfAttention(torch.nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(SelfAttention, self).__init__() + self.layer_number = max(1, layer_number) + + self.projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + self.multi_query_attention = config.multi_query_attention + self.qkv_hidden_size = 3 * self.projection_size + self.original_rope = config.original_rope + if self.multi_query_attention: + self.num_multi_query_groups_per_partition = config.multi_query_group_num + self.qkv_hidden_size = ( + self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num + ) + self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size, + bias=config.add_bias_linear or config.add_qkv_bias, + device=device, **_config_to_kwargs(config) + ) + + self.core_attention = CoreAttention(config, self.layer_number) + + # Output. + self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, + device=device, **_config_to_kwargs(config) + ) + + def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): + if self.multi_query_attention: + num_attention_heads = self.num_multi_query_groups_per_partition + else: + num_attention_heads = self.num_attention_heads_per_partition + return torch.empty( + inference_max_sequence_len, + batch_size, + num_attention_heads, + self.hidden_size_per_attention_head, + dtype=dtype, + device=device, + ) + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True + ): + # hidden_states: [b, sq, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads [b, sq, h] --> [b, sq, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view( + query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + key_layer = key_layer.view( + key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + ) + value_layer = value_layer.view( + value_layer.size()[:-1] + + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + ) + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [b, sq, np, 3 * hn] --> 3 [b, sq, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # [b, sq, np, hn] -> [b, np, sq, hn] + query_layer, key_layer, value_layer = [k.transpose(1, 2) for k in [query_layer, key_layer, value_layer]] + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + # adjust key and value for inference + if kv_cache is not None: + cache_k, cache_v = kv_cache + key_layer = torch.cat((cache_k, key_layer), dim=2) + value_layer = torch.cat((cache_v, value_layer), dim=2) + if use_cache: + kv_cache = (key_layer, value_layer) + else: + kv_cache = None + + if self.multi_query_attention: + key_layer = key_layer.unsqueeze(2) + key_layer = key_layer.expand( + -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1 + ) + key_layer = key_layer.contiguous().view( + key_layer.size()[:1] + (self.num_attention_heads_per_partition,) + key_layer.size()[3:] + ) + value_layer = value_layer.unsqueeze(2) + value_layer = value_layer.expand( + -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1 + ) + value_layer = value_layer.contiguous().view( + value_layer.size()[:1] + (self.num_attention_heads_per_partition,) + value_layer.size()[3:] + ) + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + + # ================= + # Output. [sq, b, h] + # ================= + + output = self.dense(context_layer) + + return output, kv_cache + + +def _config_to_kwargs(args): + common_kwargs = { + "dtype": args.torch_dtype, + } + return common_kwargs + + +class MLP(torch.nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config: ChatGLMConfig, device=None): + super(MLP, self).__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + self.dense_h_to_4h = nn.Linear( + config.hidden_size, + config.ffn_hidden_size * 2, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config) + ) + + def swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + + self.activation_func = swiglu + + # Project back to h. + self.dense_4h_to_h = nn.Linear( + config.ffn_hidden_size, + config.hidden_size, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config) + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(torch.nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(GLMBlock, self).__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm + + self.fp32_residual_connection = config.fp32_residual_connection + + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Layernorm on the input data. + self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + dtype=config.torch_dtype) + + # Self attention. + self.self_attention = SelfAttention(config, layer_number, device=device) + self.hidden_dropout = config.hidden_dropout + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + dtype=config.torch_dtype) + + # MLP + self.mlp = MLP(config, device=device) + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, + ): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, + attention_mask, + rotary_pos_emb, + kv_cache=kv_cache, + use_cache=use_cache + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) + layernorm_input = residual + layernorm_input + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) + output = residual + output + + return output, kv_cache + + +class GLMTransformer(torch.nn.Module): + """Transformer class.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(GLMTransformer, self).__init__() + + self.fp32_residual_connection = config.fp32_residual_connection + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. + def build_layer(layer_number): + return GLMBlock(config, layer_number, device=device) + + self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) + + if self.post_layer_norm: + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Final layer norm before output. + self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + dtype=config.torch_dtype) + + self.gradient_checkpointing = False + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None, + use_cache: Optional[bool] = True, + output_hidden_states: Optional[bool] = False, + ): + if not kv_caches: + kv_caches = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + for index in range(self.num_layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer = self._get_layer(index) + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches[index], + use_cache, + use_reentrant=False + ) + else: + layer_ret = layer( + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=kv_caches[index], + use_cache=use_cache + ) + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Final layer norm. + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, presents, all_hidden_states, all_self_attentions + + +class ChatGLMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLMConfig + base_model_prefix = "transformer" + _no_split_modules = ["GLMBlock"] + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_embeds, past_key_values, padding_mask=None): + batch_size, seq_length, embed_size = input_embeds.shape + full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_embeds.device) + full_attention_mask.tril_() + past_length = 0 + if past_key_values: + past_length = past_key_values[0][0].shape[2] + if past_length: + full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length, + device=input_embeds.device), full_attention_mask), dim=-1) + if padding_mask is not None: + full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) + if not past_length and padding_mask is not None: + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + def get_position_ids(self, input_ids, device): + batch_size, seq_length = input_ids.shape + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + return position_ids + + def get_multimodal_position_ids(self, input_ids, device): + batch_size, seq_length = input_ids.shape + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + +class Embedding(torch.nn.Module): + """Language model embeddings.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(Embedding, self).__init__() + + self.hidden_size = config.hidden_size + # Word embeddings (parallel). + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, + self.hidden_size, + dtype=config.torch_dtype, + device=device + ) + self.fp32_residual_connection = config.fp32_residual_connection + + def forward(self, input_ids): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + embeddings = words_embeddings + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + return embeddings + + +@wrap +def is_empty(images_list: Optional[List[List[torch.Tensor]]]): + if images_list is None or len(images_list) == 0: + return True + for image_list in images_list: + if image_list is not None: + return False + return True + + +class ChatGLMModel(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + init_kwargs = {} + if device is not None: + init_kwargs["device"] = device + self.embedding = init_method(Embedding, config, **init_kwargs) + self.num_layers = config.num_layers + self.multi_query_group_num = config.multi_query_group_num + self.kv_channels = config.kv_channels + + # Rotary positional embeddings + self.seq_length = config.seq_length + rotary_dim = ( + config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels + ) + + self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio, + original_impl=config.original_rope, + device=device, dtype=config.torch_dtype) + self.encoder = init_method(GLMTransformer, config, **init_kwargs) + self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False, + dtype=config.torch_dtype, **init_kwargs) + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + self.vision = EVA2CLIPModel(config) + + def get_input_embeddings(self): + return self.embedding.word_embeddings + + def set_input_embeddings(self, value): + self.embedding.word_embeddings = value + + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.pre_seq_len, + self.num_layers * 2, + self.multi_query_group_num, + self.kv_channels + ) + # seq_len, b, nh, hidden_size + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids: torch.LongTensor = None, + images: torch.Tensor = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """take care of image_encode, position_ids and (attention_mask = None is fine)""" + + # generate mode with past_key_values. the image features are already mapped + if past_key_values is None: + # not allow for inputs_embeds, because we want to process image feature + assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}" + if not is_empty(images): # multi-modality + image_size: int = self.config.vision_config['image_size'] + patch_size: int = self.config.vision_config['patch_size'] + num_patches = (image_size // patch_size // 2) ** 2 + assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}" + inputs_embeds = self.embedding(input_ids) + + images = images.to(dtype=inputs_embeds.dtype) + images_features = self.vision(images) + + if position_ids is None: + position_ids = self.get_position_ids(input_ids, device=inputs_embeds.device) + new_input_embeds, new_position_ids = [], [] + + for i in range(len(input_ids)): + input_id = input_ids[i].tolist() + boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index( + self.config.eoi_token_id) + #assert eoi_token_pos - boi_token_pos == 2 + new_input_embeds.append(torch.cat( + (inputs_embeds[i, :boi_token_pos], images_features[i].to(inputs_embeds.device), + inputs_embeds[i, eoi_token_pos + 1:]))) + new_position_ids.append(torch.cat( + (position_ids[i, :boi_token_pos + 1], position_ids[i, boi_token_pos + 1].repeat(num_patches), + position_ids[i, eoi_token_pos:]) + )) + inputs_embeds = torch.stack(new_input_embeds, dim=0) + position_ids = torch.stack(new_position_ids, dim=0) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device, + dtype=inputs_embeds.dtype) + if attention_mask is not None: + attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask], dim=-1) + + if full_attention_mask is None: + if True: #if (attention_mask is not None and not attention_mask.all().item()) or (past_key_values and seq_length != 1): + if True: #if self.training: + # https://github.com/THUDM/GLM-4/issues/264 + new_input_ids, new_attention_mask = [], [] + for i in range(len(input_ids)): + input_id = input_ids[i].tolist() + boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index(self.config.eoi_token_id) + #assert eoi_token_pos - boi_token_pos == 2 + + new_attention_mask.append(torch.cat( + (attention_mask[i, :boi_token_pos + 1], torch.ones(num_patches).to(attention_mask.device), + attention_mask[i, eoi_token_pos:]))) + + new_input_ids.append(torch.cat( + (input_ids[i, :boi_token_pos + 1], input_ids[i, -1].repeat(num_patches), + input_ids[i, eoi_token_pos:]))) + + attention_mask = torch.stack(new_attention_mask, dim=0) + input_ids = torch.stack(new_input_ids, dim=0) + inputs_embeds = self.embedding(input_ids) + + full_attention_mask = self.get_masks(inputs_embeds, past_key_values, padding_mask=attention_mask) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + + # Run encoder. + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states + ) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +def _history_to_prompt(history, query): + prompt = '' + flag = False + for i, (old_query, response) in enumerate(history): + prompt += ('<|user|>' if flag else '') + old_query + "<|assistant|>" + response + "<|endoftext|>" + flag = True + prompt += '{}{}<|assistant|>'.format('<|user|>' if flag else '', query) + return prompt + + +class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): + super().__init__(config) + + self.max_sequence_length = config.max_length + self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) + self.config = config + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + cache_name, cache = self._extract_past_from_model_output(outputs) + model_kwargs[cache_name] = cache + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + + # update position ids + if "position_ids" in model_kwargs: + position_ids = model_kwargs["position_ids"] + new_position_id = position_ids[..., -1:].clone() + new_position_id += 1 + model_kwargs["position_ids"] = torch.cat( + [position_ids, new_position_id], dim=-1 + ) + + model_kwargs["is_first_forward"] = False + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + images: Optional[torch.Tensor] = None, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + is_first_forward: bool = True, + **kwargs + ) -> dict: + # only last token for input_ids if past is not None + if position_ids is None: + position_ids = self.get_position_ids(input_ids, device=input_ids.device) + if attention_mask is not None: + image_size: int = self.config.vision_config['image_size'] + patch_size: int = self.config.vision_config['patch_size'] + num_patches = (image_size // patch_size // 2) ** 2 + new_attention_masks = [] + + # if not image, use this default id + eoi_token_pos = 6 + boi_token_pos = 4 + + for i in range(len(input_ids)): + input_id = input_ids[i].tolist() + if not is_empty(images): + boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index( + self.config.eoi_token_id) + #assert eoi_token_pos - boi_token_pos == 2 + new_attention_masks.append(torch.cat( + (attention_mask[i, :boi_token_pos + 1], attention_mask.new_ones(num_patches), + attention_mask[i, eoi_token_pos:]) + )) + attention_mask = torch.stack(new_attention_masks, dim=0) + if not is_first_forward: + if past_key_values is not None: + position_ids = position_ids[..., -1:] + input_ids = input_ids[:, -1:] + return { + "input_ids": input_ids, + "images": images, + "past_key_values": past_key_values, + "position_ids": position_ids, + "attention_mask": attention_mask, + "return_last_logit": True, + "use_cache": use_cache + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + images: List[List[torch.Tensor]] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + images=images, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[:, -1:] + lm_logits = self.transformer.output_layer(hidden_states) + + loss = None + if labels is not None: + new_labels = [] + for i in range(len(input_ids)): + input_id = input_ids[i].tolist() + boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index( + self.config.eoi_token_id) + #assert eoi_token_pos - boi_token_pos == 2 + + new_labels.append(torch.cat( + ( + labels[i, :boi_token_pos + 1], + torch.tensor([-100]).to(labels.device).to(labels.dtype).repeat(1600), + labels[i, eoi_token_pos:]))) + + labels = torch.stack(new_labels, dim=0) + lm_logits = lm_logits.to(torch.float32) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. + """ + return tuple( + ( + layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(0, beam_idx.to(layer_past[1].device)), + ) + for layer_past in past + ) + +class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): + super().__init__(config) + + self.num_labels = config.num_labels + self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) + + self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half) + if config.classifier_dropout is not None: + self.dropout = nn.Dropout(config.classifier_dropout) + else: + self.dropout = None + self.config = config + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + full_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + full_attention_mask=full_attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + pooled_hidden_states = hidden_states[-1] + if self.dropout is not None: + pooled_hidden_states = self.dropout(pooled_hidden_states) + logits = self.classifier_head(pooled_hidden_states) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze().float(), labels.squeeze()) + else: + loss = loss_fct(logits.float(), labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits.float(), labels.view(-1, self.num_labels)) + + if not return_dict: + output = (logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) \ No newline at end of file diff --git a/src/llmcompressor/transformers/tracing/glm/visual.py b/src/llmcompressor/transformers/tracing/glm/visual.py new file mode 100644 index 000000000..6a88b747c --- /dev/null +++ b/src/llmcompressor/transformers/tracing/glm/visual.py @@ -0,0 +1,180 @@ +import torch +from torch import nn +from argparse import Namespace +import torch.nn.functional as F +from transformers.activations import ACT2FN +import math +from torch.nn import LayerNorm + + +def standard_attention(query_layer, key_layer, value_layer, scaling_attention_score=True): + if scaling_attention_score: + query_layer = query_layer / math.sqrt(query_layer.shape[-1]) + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_probs = F.softmax(attention_scores, dim=-1) + + context_layer = torch.matmul(attention_probs, value_layer) + return context_layer + + +def attention_fn_default(query_layer, key_layer, value_layer, scaling_attention_score=True): + if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score: + # Pytorch 2.0 attention uses very much memory if attention_mask is float, and has NaN bug if attention_mask is None. + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, + attn_mask=None, + dropout_p=0., + is_causal=False + ) + return attn_output + else: + return standard_attention( + query_layer, key_layer, value_layer, scaling_attention_score=scaling_attention_score + ) + + +class PatchEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.proj = nn.Conv2d(config.in_channels, config.hidden_size, kernel_size=config.patch_size, + stride=config.patch_size) + self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size) + + def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)": + x = self.proj(images) + x = x.flatten(2).transpose(1, 2) + cls_token = self.cls_embedding.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + x += self.position_embedding.weight.unsqueeze(0) + return x + + +class Attention(nn.Module): + def __init__(self, config): + super().__init__() + self.num_heads = config.num_heads + head_dim = config.hidden_size // config.num_heads + self.scale = head_dim ** -0.5 + self.query_key_value = nn.Linear(config.hidden_size, config.hidden_size * 3) + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.output_dropout = torch.nn.Dropout(config.dropout_prob) + + def forward(self, x: "tensor(B, L, D)") -> "tensor(B, L, D)": + B, L, _ = x.shape + qkv = self.query_key_value(x) + qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, H, L, D + q, k, v = qkv[0], qkv[1], qkv[2] + + out = attention_fn_default( + q, k, v + ) + output = self.dense(out.transpose(1, 2).reshape(B, L, -1)) + output = self.output_dropout(output) + return output + + def attention(self, q, k, v): + attn_weights = torch.matmul(q * self.scale, k.transpose(-2, -1)) + attn_weights = attn_weights.softmax(dim=-1) + output = torch.matmul(attn_weights, v) + return output + + +class MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc1(x) + x = self.activation_fn(x) + x = self.fc2(x) + return x + + +class TransformerLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = Attention(config) + self.mlp = MLP(config) + self.post_attention_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + attention_input = hidden_states + attention_output = self.input_layernorm(self.attention(attention_input)) + hidden_states = attention_input + attention_output + mlp_input = hidden_states + + # https://github.com/THUDM/GLM-4/issues/350 + mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)).to(mlp_input.device) + output = mlp_input + mlp_output + return output + + +class Transformer(nn.Module): + def __init__(self, config): + super().__init__() + self.layers = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward(self, hidden_states): + for layer_module in self.layers: + hidden_states = layer_module(hidden_states) + return hidden_states + + +class GLU(nn.Module): + def __init__(self, config, in_features): + super().__init__() + self.linear_proj = nn.Linear(in_features, config.hidden_size, bias=False) + self.norm1 = nn.LayerNorm(config.hidden_size) + self.act1 = nn.GELU() + self.act2 = nn.functional.silu + self.dense_h_to_4h = nn.Linear(config.hidden_size, config.ffn_hidden_size, bias=False) + self.gate_proj = nn.Linear(config.hidden_size, config.ffn_hidden_size, bias=False) + self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, config.hidden_size, bias=False) + + def forward(self, x): + x = self.linear_proj(x) + x = self.act1(self.norm1(x)) + x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x) + x = self.dense_4h_to_h(x) + return x + + +class EVA2CLIPModel(nn.Module): + def __init__(self, config): + super().__init__() + vision_config = Namespace(**config.vision_config) + self.patch_embedding = PatchEmbedding(vision_config) + self.transformer = Transformer(vision_config) + self.linear_proj = GLU(config, in_features=config.hidden_size) + self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, out_channels=config.hidden_size, kernel_size=2, + stride=2) + self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.scaling_factor = vision_config.scaling_factor + + def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)": + x = self.patch_embedding(images) + x = self.transformer(x) + x = x[:, 1:] + + b, s, h = x.shape + grid_size = int(s ** 0.5) + x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2) + x = self.conv(x) + + x = x.flatten(2).transpose(1, 2) + x = self.linear_proj(x) + + # https://github.com/THUDM/GLM-4/issues/350 + boi = self.boi.expand(x.shape[0], -1, -1).to(x.device) + eoi = self.eoi.expand(x.shape[0], -1, -1).to(x.device) + x = torch.cat((boi, x, eoi), dim=1) + x = x / self.scaling_factor + return x \ No newline at end of file diff --git a/src/llmcompressor/transformers/utils/data_collator.py b/src/llmcompressor/transformers/utils/data_collator.py index 360fe2f82..9a566647d 100644 --- a/src/llmcompressor/transformers/utils/data_collator.py +++ b/src/llmcompressor/transformers/utils/data_collator.py @@ -41,3 +41,13 @@ def qwen2_vl_data_collator(batch): "pixel_values": torch.tensor(batch[0]["pixel_values"]), "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]), } + + +def glm_data_collator(batch): + assert len(batch) == 1 + return { + "input_ids": torch.LongTensor(batch[0]["input_ids"]), + "attention_mask": torch.tensor(batch[0]["attention_mask"]), + "position_ids": torch.tensor(batch[0]["position_ids"]), + "images": torch.tensor(batch[0]["images"]), + } \ No newline at end of file From 5303df2a9145b39bf195682b7ea8402aec267360 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 27 Dec 2024 16:17:23 +0000 Subject: [PATCH 248/285] list of unfixable errors Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/quantization/gptq/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index d01584aa6..3b0e15cb4 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -217,6 +217,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: # infer pipeline model_name = state.model.__class__.__name__ input_names = state.data.calib.dataset.column_names + unfixable_errors = (torch.OutOfMemoryError, torch._C._LinAlgError) try: run_sequential( state.model, @@ -230,7 +231,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: except Exception as exception: if isinstance(exception, torch.fx.proxy.TraceError): warnings.warn(f"Failed to trace {model_name} with inputs {input_names}") - if isinstance(exception, torch.OutOfMemoryError): + if isinstance(exception, unfixable_errors): raise exception warnings.warn("Falling back to layer_sequential pipeline") @@ -246,7 +247,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: except Exception as exception: if isinstance(exception, TypeError): warnings.warn(f"{model_name} fails layer-wise assumptions") - if isinstance(exception, torch.OutOfMemoryError): + if isinstance(exception, unfixable_errors): raise exception warnings.warn( From 14cbc979ca853b365d3e334bb99892c3280fcfd9 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sun, 29 Dec 2024 11:19:19 -0500 Subject: [PATCH 249/285] add glm license, style Signed-off-by: Kyle Sayers --- pyproject.toml | 2 +- .../transformers/tracing/__init__.py | 2 +- .../transformers/tracing/glm/LICENSE | 84 +++++++++++++++++++ .../tracing/glm/configuration_chatglm.py | 1 + .../tracing/glm/modeling_chatglm.py | 33 +++++--- .../transformers/tracing/glm/visual.py | 10 ++- .../transformers/utils/data_collator.py | 2 +- 7 files changed, 117 insertions(+), 17 deletions(-) create mode 100644 src/llmcompressor/transformers/tracing/glm/LICENSE diff --git a/pyproject.toml b/pyproject.toml index 98661216d..90164e056 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ profile = "black" files = "src/guidellm" [tool.ruff] -exclude = ["build", "dist", "env", ".venv", "src/llmcompressor/pytorch/tracing"] +exclude = ["build", "dist", "env", ".venv", "src/llmcompressor/transformers/tracing"] lint.select = ["E", "F", "W"] [tool.flake8] diff --git a/src/llmcompressor/transformers/tracing/__init__.py b/src/llmcompressor/transformers/tracing/__init__.py index 3bb896d1f..4d846c536 100644 --- a/src/llmcompressor/transformers/tracing/__init__.py +++ b/src/llmcompressor/transformers/tracing/__init__.py @@ -1,3 +1,4 @@ +from .glm.modeling_chatglm import ChatGLMForConditionalGeneration from .llava import ( LlavaForConditionalGeneration as TracableLlavaForConditionalGeneration, ) @@ -5,7 +6,6 @@ from .mllama import ( MllamaForConditionalGeneration as TracableMllamaForConditionalGeneration, ) -from .glm.modeling_chatglm import ChatGLMForConditionalGeneration __all__ = [ "TracableLlavaForConditionalGeneration", diff --git a/src/llmcompressor/transformers/tracing/glm/LICENSE b/src/llmcompressor/transformers/tracing/glm/LICENSE new file mode 100644 index 000000000..7b7c19f56 --- /dev/null +++ b/src/llmcompressor/transformers/tracing/glm/LICENSE @@ -0,0 +1,84 @@ +The glm-4-9b License + +1. 定义 + +“许可方”是指分发其软件的 glm-4-9b 模型团队。 +“软件”是指根据本许可提供的 glm-4-9b 模型参数。 + +2. 许可授予 + +根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。 +本许可允许您免费使用本仓库中的所有开源模型进行学术研究,对于希望将模型用于商业目的的用户,需在[这里](https://open.bigmodel.cn/mla/form)完成登记。经过登记的用户可以免费使用本模型进行商业活动,但必须遵守本许可的所有条款和条件。 +上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。 +如果您分发或提供 THUDM / 智谱AI 关于 glm-4 开源模型的材料(或其任何衍生作品),或使用其中任何材料(包括 glm-4 系列的所有开源模型)的产品或服务,您应: + +(A) 随任何此类 THUDM / 智谱AI 材料提供本协议的副本; +(B) 在相关网站、用户界面、博客文章、关于页面或产品文档上突出显示 “Built with glm-4”。 +如果您使用 THUDM / 智谱AI的 glm-4 开源模型的材料来创建、训练、微调或以其他方式改进已分发或可用的 AI 模型,您还应在任何此类 AI 模型名称的开头添加 “glm-4”。 + +3. 限制 + +您不得出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。 +您不得利用本软件从事任何危害国家安全和国家统一,危害社会公共利益及公序良俗,侵犯他人商业秘密、知识产权、名誉权、肖像权、财产权等权益的行为。 +您在使用中应遵循使用地所适用的法律法规政策、道德规范等要求。 + +4. 免责声明 + +本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。 +在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 +软件。 + +5. 责任限制 + +除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 +或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。 + +6. 争议解决 + +本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。 +请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 license@zhipuai.cn 与我们联系。 + +1. Definitions + +“Licensor” means the glm-4-9b Model Team that distributes its Software. +“Software” means the glm-4-9b model parameters made available under this license. + +2. License + +Under the terms and conditions of this license, the Licensor hereby grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license. +This license allows you to use all open source models in this repository for free for academic research. For users who wish to use the models for commercial purposes, please do so [here](https://open.bigmodel.cn/mla/form) +Complete registration. Registered users are free to use this model for commercial activities, but must comply with all terms and conditions of this license. +The copyright notice and this license notice shall be included in all copies or substantial portions of the Software. +If you distribute or provide THUDM / Zhipu AI materials on the glm-4 open source model (or any derivative works thereof), or products or services that use any materials therein (including all open source models of the glm-4 series), you should: + +(A) Provide a copy of this Agreement with any such THUDM/Zhipu AI Materials; +(B) Prominently display "Built with glm-4" on the relevant website, user interface, blog post, related page or product documentation. +If you use materials from THUDM/Zhipu AI's glm-4 model to create, train, operate, or otherwise improve assigned or available AI models, you should also add "glm-4" to the beginning of any such AI model name. + +3. Restrictions + +You are not allowed to use, copy, modify, merge, publish, distribute, copy or create all or part of the derivative works of this software for any military or illegal purposes. +You are not allowed to use this software to engage in any behavior that endangers national security and unity, endangers social public interests and public order, infringes on the rights and interests of others such as trade secrets, intellectual property rights, reputation rights, portrait rights, and property rights. +You should comply with the applicable laws, regulations, policies, ethical standards, and other requirements in the place of use during use. + +4. Disclaimer + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE +WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +5. Limitation of Liability + +EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, +NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, +INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED +OF THE POSSIBILITY OF SUCH DAMAGES. + +6. Dispute Resolution + +This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute +arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. + +Note that the license is subject to update to a more comprehensive version. For any questions related to the license and +copyright, please contact us at license@zhipuai.cn. \ No newline at end of file diff --git a/src/llmcompressor/transformers/tracing/glm/configuration_chatglm.py b/src/llmcompressor/transformers/tracing/glm/configuration_chatglm.py index de54e92e2..c9783fcac 100644 --- a/src/llmcompressor/transformers/tracing/glm/configuration_chatglm.py +++ b/src/llmcompressor/transformers/tracing/glm/configuration_chatglm.py @@ -1,3 +1,4 @@ +# flake8: noqa from transformers import PretrainedConfig diff --git a/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py b/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py index 462ae200a..d22e88d99 100644 --- a/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py +++ b/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py @@ -1,33 +1,46 @@ +# flake8: noqa """ PyTorch GLM-4V model. """ import math import sys +from typing import Any, Dict, List, Optional, Tuple, Union + import torch -import torch.utils.checkpoint import torch.nn.functional as F +import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss from torch.nn.utils import skip_init -from typing import Optional, Tuple, Union, List, Dict, Any - +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import ( + GenerationConfig, + LogitsProcessorList, + ModelOutput, + StoppingCriteriaList, +) from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging, is_torch_npu_available -from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput +from transformers.utils import is_torch_npu_available, logging -from .visual import EVA2CLIPModel from .configuration_chatglm import ChatGLMConfig +from .visual import EVA2CLIPModel try: - from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available + from transformers.utils import ( + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + ) if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + from flash_attn.bert_padding import ( # noqa + index_first_axis, + pad_input, + unpad_input, + ) except: pass diff --git a/src/llmcompressor/transformers/tracing/glm/visual.py b/src/llmcompressor/transformers/tracing/glm/visual.py index 6a88b747c..9f01098dd 100644 --- a/src/llmcompressor/transformers/tracing/glm/visual.py +++ b/src/llmcompressor/transformers/tracing/glm/visual.py @@ -1,10 +1,12 @@ -import torch -from torch import nn +# flake8: noqa +import math from argparse import Namespace + +import torch import torch.nn.functional as F -from transformers.activations import ACT2FN -import math +from torch import nn from torch.nn import LayerNorm +from transformers.activations import ACT2FN def standard_attention(query_layer, key_layer, value_layer, scaling_attention_score=True): diff --git a/src/llmcompressor/transformers/utils/data_collator.py b/src/llmcompressor/transformers/utils/data_collator.py index 9a566647d..bd4327b01 100644 --- a/src/llmcompressor/transformers/utils/data_collator.py +++ b/src/llmcompressor/transformers/utils/data_collator.py @@ -50,4 +50,4 @@ def glm_data_collator(batch): "attention_mask": torch.tensor(batch[0]["attention_mask"]), "position_ids": torch.tensor(batch[0]["position_ids"]), "images": torch.tensor(batch[0]["images"]), - } \ No newline at end of file + } From ff470b3ec2d84dbf59d9df819a2e9084362eb654 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 1 Jan 2025 20:48:52 -0500 Subject: [PATCH 250/285] add suggestion to use offload_hessians Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/quantization/gptq/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 3b0e15cb4..550b9849b 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -252,7 +252,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool: warnings.warn( "Falling back to basic pipeline, which requires extra memory and " - "may result in decreased accuracy" + "may result in decreased accuracy. Consider using " + "`offload_hessians=True`" ) run_basic(state.model, state.data.calib) return True From c1c3eaae56d65ef71c425ae7b48feec1ee5cf4e5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 1 Jan 2025 21:35:57 -0500 Subject: [PATCH 251/285] update names and comments Signed-off-by: Kyle Sayers --- src/llmcompressor/pipelines/sequential/helpers.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index e243ee4c4..e8deeb9f3 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -81,7 +81,7 @@ def get_tracer( module for module in model.modules() if has_offloaded_params(module) ) - class PiecewiseTracer(HFTracer): + class SequentialTracer(HFTracer): def create_arg(self, a: Any) -> Argument: if isinstance(a, PretrainedConfig): kwargs = {k: self.create_arg(v) for k, v in a.to_dict().items()} @@ -99,7 +99,7 @@ def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool: or super().is_leaf_module(module, module_qualified_name) ) - return PiecewiseTracer() + return SequentialTracer() def populate_concrete_args(model: Module, sample_input: Dict) -> Dict: @@ -166,9 +166,10 @@ def topological_partition(graph: GraphModule, targets: Set[Module]) -> List[List if remaining_indegrees[user] == 0: queue.append(user) - # a perfect solution would involve implicitly consolodating partition indices so - # that each node is assigned to the maximum partition possible (in order to delay - # execution as long as possible), but this covers the most costly case (get_attr) + # a perfect implementation would involve implicitly consolidating partition indices + # so that each node is assigned to the maximum partition possible (in order to delay + # execution as long as possible), but the current implementation covers the most + # common and costly case (get_attr) for node in graph.graph.find_nodes(op="get_attr"): user_partitions = [] for user in node.users: From e5af728a99387b73b307d9dda49d3fe2f0cea87e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 1 Jan 2025 21:45:51 -0500 Subject: [PATCH 252/285] change tqdm description, add comment Signed-off-by: Kyle Sayers --- src/llmcompressor/pipelines/layer_sequential/pipeline.py | 3 ++- src/llmcompressor/pipelines/sequential/pipeline.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index 334269e83..73ef421f9 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -41,13 +41,14 @@ def run_pipeline( layers = match_modules(model, sequential_targets) with calibration_forward_context(model): + # prepare intermediates cache intermediates = capture_first_layer_intermediates(model, layers, dataloader) num_layers = len(layers) for layer_index, layer in enumerate(layers): # prepare tqdm description texts calib_desc = f"({layer_index + 1}/{num_layers}): Calibrating" - prop_desc = f"({layer_index + 1}/{num_layers}): Propagate" + prop_desc = f"({layer_index + 1}/{num_layers}): Propagating" if propagate_error: # do an preliminary pass to trigger modifier hooks diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 629c0aeb5..ca7caf7a0 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -52,7 +52,7 @@ def run_pipeline( for subgraph_index, subgraph in enumerate(subgraphs): # prepare tqdm description texts calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating" - prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagate" + prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating" # compile subgraph forward function forward_function = subgraph.compile_forward() From 8fd93a7228b072899d9459f46c15f8b019dca56e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 1 Jan 2025 21:49:09 -0500 Subject: [PATCH 253/285] add no vllm copyright to glm Signed-off-by: Kyle Sayers --- .../transformers/tracing/glm/configuration_chatglm.py | 1 + src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py | 1 + src/llmcompressor/transformers/tracing/glm/visual.py | 1 + 3 files changed, 3 insertions(+) diff --git a/src/llmcompressor/transformers/tracing/glm/configuration_chatglm.py b/src/llmcompressor/transformers/tracing/glm/configuration_chatglm.py index c9783fcac..2487be8ce 100644 --- a/src/llmcompressor/transformers/tracing/glm/configuration_chatglm.py +++ b/src/llmcompressor/transformers/tracing/glm/configuration_chatglm.py @@ -1,4 +1,5 @@ # flake8: noqa +# vllm-project: no copyright from transformers import PretrainedConfig diff --git a/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py b/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py index d22e88d99..45fcff21b 100644 --- a/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py +++ b/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py @@ -1,4 +1,5 @@ # flake8: noqa +# vllm-project: no copyright """ PyTorch GLM-4V model. """ import math import sys diff --git a/src/llmcompressor/transformers/tracing/glm/visual.py b/src/llmcompressor/transformers/tracing/glm/visual.py index 9f01098dd..1c2dd7fc6 100644 --- a/src/llmcompressor/transformers/tracing/glm/visual.py +++ b/src/llmcompressor/transformers/tracing/glm/visual.py @@ -1,4 +1,5 @@ # flake8: noqa +# vllm-project: no copyright import math from argparse import Namespace From 8e5f693e2b1ad4042e1a656fa53eec9a92e5a25b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 1 Jan 2025 22:26:34 -0500 Subject: [PATCH 254/285] update comments, remove unnecessary default values --- .../modifiers/quantization/gptq/utils/gptq_quantize.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 7efe21b94..decb3a33d 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -32,15 +32,14 @@ def make_empty_hessian( def accumulate_hessian( inp: torch.Tensor, module: torch.nn.Module, - H: Optional[torch.Tensor] = None, - num_samples: int = 1, + H: Optional[torch.Tensor], + num_samples: int, ) -> Tuple[torch.Tensor, int]: inp = inp.to(device=H.device) if len(inp.shape) == 2: inp = inp.unsqueeze(0) - num_added = inp.shape[0] # note this is the number of dataset samples, not - # multiplied by the sequence length + num_added = inp.shape[0] if isinstance(module, (torch.nn.Linear, transformers.Conv1D)): if len(inp.shape) == 3: From 7ba6f60498d4a0c9c6894c69e9e375522b7d9ac1 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 3 Jan 2025 03:02:23 +0000 Subject: [PATCH 255/285] rename examples to have _example suffix Signed-off-by: Kyle Sayers --- examples/multimodal_vision/{llava.py => llava_example.py} | 0 examples/multimodal_vision/{mllama.py => mllama_example.py} | 0 examples/multimodal_vision/{pixtral.py => pixtral_example.py} | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename examples/multimodal_vision/{llava.py => llava_example.py} (100%) rename examples/multimodal_vision/{mllama.py => mllama_example.py} (100%) rename examples/multimodal_vision/{pixtral.py => pixtral_example.py} (100%) diff --git a/examples/multimodal_vision/llava.py b/examples/multimodal_vision/llava_example.py similarity index 100% rename from examples/multimodal_vision/llava.py rename to examples/multimodal_vision/llava_example.py diff --git a/examples/multimodal_vision/mllama.py b/examples/multimodal_vision/mllama_example.py similarity index 100% rename from examples/multimodal_vision/mllama.py rename to examples/multimodal_vision/mllama_example.py diff --git a/examples/multimodal_vision/pixtral.py b/examples/multimodal_vision/pixtral_example.py similarity index 100% rename from examples/multimodal_vision/pixtral.py rename to examples/multimodal_vision/pixtral_example.py From 435cf0dcdd628d7da0eae287d1c9748435a16528 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 3 Jan 2025 04:30:01 +0000 Subject: [PATCH 256/285] update all list Signed-off-by: Kyle Sayers --- src/llmcompressor/transformers/utils/data_collator.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/transformers/utils/data_collator.py b/src/llmcompressor/transformers/utils/data_collator.py index bd4327b01..930b06696 100644 --- a/src/llmcompressor/transformers/utils/data_collator.py +++ b/src/llmcompressor/transformers/utils/data_collator.py @@ -1,6 +1,12 @@ import torch -__all__ = ["mllama_data_collator", "pixtral_data_collator"] +__all__ = [ + "mllama_data_collator", + "pixtral_data_collator", + "llava_data_collator", + "qwen2_vl_data_collator", + "glm_data_collator", +] def mllama_data_collator(batch): From 0d25307849ac3a13a6fb722a369b76df332f32d1 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 3 Jan 2025 04:30:58 +0000 Subject: [PATCH 257/285] update examples to use w4a16 Signed-off-by: Kyle Sayers --- examples/multimodal_vision/llava_example.py | 2 +- examples/multimodal_vision/mllama_example.py | 2 +- examples/multimodal_vision/pixtral_example.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/multimodal_vision/llava_example.py b/examples/multimodal_vision/llava_example.py index 68653f182..e4f779b8e 100644 --- a/examples/multimodal_vision/llava_example.py +++ b/examples/multimodal_vision/llava_example.py @@ -22,7 +22,7 @@ recipe = [ GPTQModifier( targets="Linear", - scheme="W8A8", + scheme="W4A16", ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"], sequential_targets=["LlamaDecoderLayer"], ), diff --git a/examples/multimodal_vision/mllama_example.py b/examples/multimodal_vision/mllama_example.py index 20fe316be..b8ee1d23b 100644 --- a/examples/multimodal_vision/mllama_example.py +++ b/examples/multimodal_vision/mllama_example.py @@ -22,7 +22,7 @@ recipe = [ GPTQModifier( targets="Linear", - scheme="W8A8", + scheme="W4A16", ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"], ), ] diff --git a/examples/multimodal_vision/pixtral_example.py b/examples/multimodal_vision/pixtral_example.py index d9a161d01..70226a93f 100644 --- a/examples/multimodal_vision/pixtral_example.py +++ b/examples/multimodal_vision/pixtral_example.py @@ -22,7 +22,7 @@ recipe = [ GPTQModifier( targets="Linear", - scheme="W8A8", + scheme="W4A16", ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"], sequential_targets=["MistralDecoderLayer"], ), From 9abdea8f1a98bfb5f7fea7b4dceae3e05517b096 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 3 Jan 2025 16:18:30 -0500 Subject: [PATCH 258/285] llava: clarify changes, undo style changes Signed-off-by: Kyle Sayers --- .../transformers/tracing/llava.py | 98 +++++++------------ 1 file changed, 34 insertions(+), 64 deletions(-) diff --git a/src/llmcompressor/transformers/tracing/llava.py b/src/llmcompressor/transformers/tracing/llava.py index e77fda2df..0068e7b7d 100644 --- a/src/llmcompressor/transformers/tracing/llava.py +++ b/src/llmcompressor/transformers/tracing/llava.py @@ -24,19 +24,23 @@ from transformers.models.llava.modeling_llava import ( LlavaCausalLMOutputWithPast, LlavaMultiModalProjector, + LlavaPreTrainedModel, logger, ) from transformers.models.mistral.configuration_mistral import MistralConfig from transformers.utils.fx import HFProxy +# TRACING: Reuse tracable subclass from .mistral import MistralForCausalLM as TracableMistralForCausalLM +# TRACING: The shape of image_features is known and documented by +# LlavaForConditionalGeneration.get_image_features def maybe_install_metadata_image_features( image_features: Union[torch.Tensor, HFProxy], pixel_values: Union[torch.Tensor, HFProxy], config: LlavaConfig, -): +) -> Union[torch.Tensor, HFProxy]: if isinstance(image_features, HFProxy): # (num_images, image_length, embed_dim) num_images = pixel_values._metadata.size(0) @@ -52,12 +56,14 @@ def maybe_install_metadata_image_features( return image_features +# TRACING: The shape of inputs_embeds is known. This function compensates for +# the fact that shape inference through `masked_scatter` is not implemented yet def maybe_install_metadata_inputs_embeds( inputs_embeds_masked: Union[torch.Tensor, HFProxy], inputs_embeds: Union[torch.Tensor, HFProxy], special_image_mask: Union[torch.Tensor, HFProxy], image_features: Union[torch.Tensor, HFProxy], -): +) -> Union[torch.Tensor, HFProxy]: if isinstance(inputs_embeds_masked, HFProxy): metadata = inputs_embeds._metadata.masked_scatter( special_image_mask._metadata.to(bool), image_features._metadata @@ -67,27 +73,24 @@ def maybe_install_metadata_inputs_embeds( return inputs_embeds +# TRACING: override `__init__` and `forward` class LlavaForConditionalGeneration(LlavaForConditionalGeneration): def __init__(self, config: LlavaConfig): - super().__init__(config) + super(LlavaPreTrainedModel, self).__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) self.multi_modal_projector = LlavaMultiModalProjector(config) self.vocab_size = config.text_config.vocab_size - # NOT TRACABLE: Must use TracableMistralForCausalLM which wraps untracable function + # TRACING: Must use TracableMistralForCausalLM which wraps an untracable function if isinstance(config.text_config, MistralConfig): self.language_model = TracableMistralForCausalLM(config.text_config) else: self.language_model = AutoModelForCausalLM.from_config(config.text_config) - self.pad_token_id = ( - self.config.pad_token_id if self.config.pad_token_id is not None else -1 - ) + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self.post_init() - self.__class__.__name__ = "LlavaForConditionalGeneration" - def forward( self, input_ids: torch.LongTensor = None, @@ -106,23 +109,13 @@ def forward( cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, ) -> Union[Tuple, LlavaCausalLMOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict vision_feature_layer = ( - vision_feature_layer - if vision_feature_layer is not None - else self.config.vision_feature_layer + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer ) vision_feature_select_strategy = ( vision_feature_select_strategy @@ -131,9 +124,7 @@ def forward( ) if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You must specify exactly one of input_ids or inputs_embeds" - ) + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if pixel_values is not None and inputs_embeds is not None: raise ValueError( @@ -148,7 +139,8 @@ def forward( # not very reliable, but we don't expect one to actually pass 500+ images for one prompt # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True - # NOT TRACABLE, instead always use legacy_processing = False + # TRACING: Assume that the user will not pass 500+ images for a single prompt + # instead always use legacy_processing = False # legacy_processing = ( # (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length # ) or (input_ids.shape[-1] == 1 and pixel_values is not None) @@ -174,23 +166,17 @@ def forward( ) # prefill stage vs decoding stage (legacy behavior copied) if input_ids.shape[1] != 1: - inputs_embeds, attention_mask, labels, position_ids = ( - self._merge_input_ids_with_image_features( - image_features, inputs_embeds, input_ids, attention_mask, labels - ) - ) - cache_position = torch.arange( - attention_mask.shape[1], device=attention_mask.device + inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, attention_mask, labels ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) else: # Retrieve the first layer to inspect the logits and mask out the hidden states # that are set to 0 first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where( - first_layer_past_key_value.float().sum(-2) == 0 - ) + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) # Get the target length target_length = input_ids.shape[1] @@ -212,20 +198,16 @@ def forward( # Zero-out the places where we don't need to attend extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - attention_mask = torch.cat( - (extended_attention_mask, attention_mask[:, -target_length:]), dim=1 - ) + attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - cache_position = torch.arange( - attention_mask.shape[1], device=attention_mask.device - )[-target_length:] + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] # @raushan retain only the new behavior after v4.47 elif image_features is not None: n_image_tokens = (input_ids == self.config.image_token_index).sum().item() n_image_features = image_features.shape[0] * image_features.shape[1] - # NOT TRACABLE, instead always use n_image_tokens != n_image_features = False + # TRACING: Assume that processing and tokenization was done correctly # if n_image_tokens != n_image_features: if False: raise ValueError( @@ -237,16 +219,11 @@ def forward( .expand_as(inputs_embeds) .to(inputs_embeds.device) ) - image_features = image_features.to( - inputs_embeds.device, inputs_embeds.dtype - ) - inputs_embeds_masked = inputs_embeds.masked_scatter( - special_image_mask, image_features - ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds_masked = inputs_embeds.masked_scatter(special_image_mask, image_features) - inputs_embeds_masked = maybe_install_metadata_inputs_embeds( - inputs_embeds_masked, inputs_embeds, special_image_mask, image_features - ) + # TRACING: install metadata + inputs_embeds_masked = maybe_install_metadata_inputs_embeds(inputs_embeds_masked, inputs_embeds, special_image_mask, image_features) inputs_embeds = inputs_embeds_masked outputs = self.language_model( @@ -270,23 +247,16 @@ def forward( if attention_mask is not None: # we use the input attention mask to shift the logits and labels, because it is 2D. # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to( - logits.device - ) - shift_logits = logits[..., :-1, :][ - shift_attention_mask.to(logits.device) != 0 - ].contiguous() - shift_labels = labels[..., 1:][ - shift_attention_mask.to(labels.device) != 0 - ].contiguous() + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) + shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() else: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = torch.nn.CrossEntropyLoss() loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1).to(shift_logits.device), + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device), ) if not return_dict: From 3dca7b34cfb9d52fa376b526d68e2e168d1acfe3 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 3 Jan 2025 17:03:16 -0500 Subject: [PATCH 259/285] glm comments, fix isort Signed-off-by: Kyle Sayers --- pyproject.toml | 3 +- .../transformers/tracing/__init__.py | 6 ++- .../tracing/glm/modeling_chatglm.py | 37 ++++++++++--------- .../transformers/tracing/glm/visual.py | 9 ++--- 4 files changed, 30 insertions(+), 25 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 90164e056..e9cd799bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,12 +8,13 @@ target-version = ['py38'] [tool.isort] profile = "black" +skip = ["src/llmcompressor/transformers/tracing/"] [tool.mypy] files = "src/guidellm" [tool.ruff] -exclude = ["build", "dist", "env", ".venv", "src/llmcompressor/transformers/tracing"] +exclude = ["build", "dist", "env", ".venv", "src/llmcompressor/transformers/tracing/"] lint.select = ["E", "F", "W"] [tool.flake8] diff --git a/src/llmcompressor/transformers/tracing/__init__.py b/src/llmcompressor/transformers/tracing/__init__.py index 4d846c536..bb89a8ddf 100644 --- a/src/llmcompressor/transformers/tracing/__init__.py +++ b/src/llmcompressor/transformers/tracing/__init__.py @@ -1,4 +1,6 @@ -from .glm.modeling_chatglm import ChatGLMForConditionalGeneration +from .glm.modeling_chatglm import ( + ChatGLMForConditionalGeneration as TracableChatGLMForConditionalGeneration, +) from .llava import ( LlavaForConditionalGeneration as TracableLlavaForConditionalGeneration, ) @@ -11,5 +13,5 @@ "TracableLlavaForConditionalGeneration", "TracableMllamaForConditionalGeneration", "TracableMistralForCausalLM", - "ChatGLMForConditionalGeneration", + "TracableChatGLMForConditionalGeneration", ] diff --git a/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py b/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py index 45fcff21b..db40f3331 100644 --- a/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py +++ b/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py @@ -3,31 +3,29 @@ """ PyTorch GLM-4V model. """ import math import sys -from typing import Any, Dict, List, Optional, Tuple, Union - import torch -import torch.nn.functional as F import torch.utils.checkpoint +import torch.nn.functional as F from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss +from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss from torch.nn.utils import skip_init -from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import ( - GenerationConfig, - LogitsProcessorList, - ModelOutput, - StoppingCriteriaList, -) +from typing import Optional, Tuple, Union, List, Dict, Any + from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) from transformers.modeling_utils import PreTrainedModel -from transformers.utils import is_torch_npu_available, logging +from transformers.utils import logging, is_torch_npu_available +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput -from .configuration_chatglm import ChatGLMConfig from .visual import EVA2CLIPModel +from .configuration_chatglm import ChatGLMConfig + +# TRACING: import wrap +from torch.fx import wrap try: from transformers.utils import ( @@ -45,8 +43,6 @@ except: pass -from torch.fx import wrap - # flags required to enable jit fusion kernels if sys.platform != 'darwin' and not is_torch_npu_available(): @@ -890,6 +886,7 @@ def forward(self, input_ids): return embeddings +# TRACING: this function is untracable @wrap def is_empty(images_list: Optional[List[List[torch.Tensor]]]): if images_list is None or len(images_list) == 0: @@ -997,6 +994,7 @@ def forward( input_id = input_ids[i].tolist() boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index( self.config.eoi_token_id) + # TRACING: Assume that processing and tokenization was done correctly #assert eoi_token_pos - boi_token_pos == 2 new_input_embeds.append(torch.cat( (inputs_embeds[i, :boi_token_pos], images_features[i].to(inputs_embeds.device), @@ -1028,13 +1026,16 @@ def forward( attention_mask], dim=-1) if full_attention_mask is None: - if True: #if (attention_mask is not None and not attention_mask.all().item()) or (past_key_values and seq_length != 1): - if True: #if self.training: + # TRACING: Assume only prefill and that the attention mask is full + #if (attention_mask is not None and not attention_mask.all().item()) or (past_key_values and seq_length != 1): + if False: #if (attention_mask is not None and not attention_mask.all().item()) or (past_key_values and seq_length != 1): + if self.training: # https://github.com/THUDM/GLM-4/issues/264 new_input_ids, new_attention_mask = [], [] for i in range(len(input_ids)): input_id = input_ids[i].tolist() boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index(self.config.eoi_token_id) + # TRACING: Assume that processing and tokenization was done correctly #assert eoi_token_pos - boi_token_pos == 2 new_attention_mask.append(torch.cat( @@ -1152,6 +1153,7 @@ def prepare_inputs_for_generation( if not is_empty(images): boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index( self.config.eoi_token_id) + # TRACING: Assume that processing and tokenization was done correctly #assert eoi_token_pos - boi_token_pos == 2 new_attention_masks.append(torch.cat( (attention_mask[i, :boi_token_pos + 1], attention_mask.new_ones(num_patches), @@ -1214,6 +1216,7 @@ def forward( input_id = input_ids[i].tolist() boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index( self.config.eoi_token_id) + # TRACING: Assume that processing and tokenization was done correctly #assert eoi_token_pos - boi_token_pos == 2 new_labels.append(torch.cat( diff --git a/src/llmcompressor/transformers/tracing/glm/visual.py b/src/llmcompressor/transformers/tracing/glm/visual.py index 1c2dd7fc6..1a7792e57 100644 --- a/src/llmcompressor/transformers/tracing/glm/visual.py +++ b/src/llmcompressor/transformers/tracing/glm/visual.py @@ -1,13 +1,12 @@ # flake8: noqa # vllm-project: no copyright -import math -from argparse import Namespace - import torch -import torch.nn.functional as F from torch import nn -from torch.nn import LayerNorm +from argparse import Namespace +import torch.nn.functional as F from transformers.activations import ACT2FN +import math +from torch.nn import LayerNorm def standard_attention(query_layer, key_layer, value_layer, scaling_attention_score=True): From f416674ac26235dfc9d1476f963cff55386faf9c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 3 Jan 2025 17:06:41 -0500 Subject: [PATCH 260/285] correct typo 'tracable' Signed-off-by: Kyle Sayers --- examples/multimodal_vision/llava_example.py | 4 +- examples/multimodal_vision/mllama_example.py | 4 +- examples/multimodal_vision/pixtral_example.py | 4 +- .../quantization_w4a16/minicpm_example.py | 82 +++++++++++++++++ .../quantization_w4a16/mixtral_example.py | 89 +++++++++++++++++++ .../pipelines/sequential/pipeline.py | 4 +- .../transformers/tracing/__init__.py | 16 ++-- .../tracing/glm/modeling_chatglm.py | 2 +- .../transformers/tracing/llava.py | 8 +- .../transformers/tracing/mllama.py | 2 +- 10 files changed, 193 insertions(+), 22 deletions(-) create mode 100644 examples/quantization_w4a16/minicpm_example.py create mode 100644 examples/quantization_w4a16/mixtral_example.py diff --git a/examples/multimodal_vision/llava_example.py b/examples/multimodal_vision/llava_example.py index e4f779b8e..c86cf0dfe 100644 --- a/examples/multimodal_vision/llava_example.py +++ b/examples/multimodal_vision/llava_example.py @@ -2,12 +2,12 @@ from llmcompressor.modifiers.quantization import GPTQModifier from llmcompressor.transformers import oneshot -from llmcompressor.transformers.tracing import TracableLlavaForConditionalGeneration +from llmcompressor.transformers.tracing import TraceableLlavaForConditionalGeneration from llmcompressor.transformers.utils.data_collator import llava_data_collator # Load model. model_id = "llava-hf/llava-1.5-7b-hf" -model = TracableLlavaForConditionalGeneration.from_pretrained( +model = TraceableLlavaForConditionalGeneration.from_pretrained( model_id, device_map="auto", torch_dtype="auto" ) processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) diff --git a/examples/multimodal_vision/mllama_example.py b/examples/multimodal_vision/mllama_example.py index b8ee1d23b..16c17f18e 100644 --- a/examples/multimodal_vision/mllama_example.py +++ b/examples/multimodal_vision/mllama_example.py @@ -2,12 +2,12 @@ from llmcompressor.modifiers.quantization import GPTQModifier from llmcompressor.transformers import oneshot -from llmcompressor.transformers.tracing import TracableMllamaForConditionalGeneration +from llmcompressor.transformers.tracing import TraceableMllamaForConditionalGeneration from llmcompressor.transformers.utils.data_collator import mllama_data_collator # Load model. model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" -model = TracableMllamaForConditionalGeneration.from_pretrained( +model = TraceableMllamaForConditionalGeneration.from_pretrained( model_id, device_map="auto", torch_dtype="auto" ) processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) diff --git a/examples/multimodal_vision/pixtral_example.py b/examples/multimodal_vision/pixtral_example.py index 70226a93f..e068a6dc9 100644 --- a/examples/multimodal_vision/pixtral_example.py +++ b/examples/multimodal_vision/pixtral_example.py @@ -2,12 +2,12 @@ from llmcompressor.modifiers.quantization import GPTQModifier from llmcompressor.transformers import oneshot -from llmcompressor.transformers.tracing import TracableLlavaForConditionalGeneration +from llmcompressor.transformers.tracing import TraceableLlavaForConditionalGeneration from llmcompressor.transformers.utils.data_collator import pixtral_data_collator # Load model. model_id = "mgoin/pixtral-12b" -model = TracableLlavaForConditionalGeneration.from_pretrained( +model = TraceableLlavaForConditionalGeneration.from_pretrained( model_id, device_map="auto", torch_dtype="auto" ) processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) diff --git a/examples/quantization_w4a16/minicpm_example.py b/examples/quantization_w4a16/minicpm_example.py new file mode 100644 index 000000000..826d01d34 --- /dev/null +++ b/examples/quantization_w4a16/minicpm_example.py @@ -0,0 +1,82 @@ +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import oneshot + +# Select model and load it. +MODEL_ID = "openbmb/MiniCPM3-4B" + +model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, + device_map="auto", + torch_dtype="auto", + trust_remote_code=True, +) +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) + +# Select calibration dataset. +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" + +# Select number of samples. 512 samples is a good place to start. +# Increasing the number of samples can improve accuracy. +NUM_CALIBRATION_SAMPLES = 512 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) +ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) + + +def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# Configure the quantization algorithm to run. +# * quantize the weights to 4 bit with GPTQ with a group size 128 +recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]) + +# Apply algorithms. +oneshot( + model=model, + tokenizer=tokenizer, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, +) + +# Save to disk compressed. +SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) + +# Confirm generations of the quantized model look sane. +print("\n\n") +print("========== SAMPLE GENERATION ==============") +input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=100) +print(tokenizer.decode(output[0])) +print("==========================================\n\n") diff --git a/examples/quantization_w4a16/mixtral_example.py b/examples/quantization_w4a16/mixtral_example.py new file mode 100644 index 000000000..10834725a --- /dev/null +++ b/examples/quantization_w4a16/mixtral_example.py @@ -0,0 +1,89 @@ +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import oneshot +from llmcompressor.transformers.compression.helpers import calculate_offload_device_map + +# Select model and load it. +MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1" +device_map = calculate_offload_device_map( + MODEL_ID, + reserve_for_hessians=True, + num_gpus=1, + torch_dtype="auto", +) + +model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, + device_map=device_map, + torch_dtype="auto", +) +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +# Select calibration dataset. +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" + +# Select number of samples. 512 samples is a good place to start. +# Increasing the number of samples can improve accuracy. +NUM_CALIBRATION_SAMPLES = 512 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) +ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) + + +def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# Configure the quantization algorithm to run. +# * quantize the weights to 4 bit with GPTQ with a group size 128 +recipe = GPTQModifier( + targets="Linear", scheme="W4A16", ignore=["lm_head", "re:.*block_sparse_moe.gate"] +) + +# Apply algorithms. +oneshot( + model=model, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, +) + +# Confirm generations of the quantized model look sane. +print("\n\n") +print("========== SAMPLE GENERATION ==============") +input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=100) +print(tokenizer.decode(output[0])) +print("==========================================\n\n") + +# Save to disk compressed. +SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index ca7caf7a0..f3e0a762f 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -31,12 +31,12 @@ def run_pipeline( 3. The intermediate activations between each subgraph are cached and offloaded to the cpu between each batch in order to save memory - This pipeline requires that the model be tracable with respect to data from the + This pipeline requires that the model be traceable with respect to data from the data loader. This may be an issue for vision language models with vision datasets, due to specialized input processing in the model. In the event that tracing fails, a torch.fx.proxy.TraceError will be raised. A model - can be made tracable by wrapping the untracable functions (see + can be made traceable by wrapping the untraceable functions (see llmcompressor.transformers.tracing) """ # trace subgraphs diff --git a/src/llmcompressor/transformers/tracing/__init__.py b/src/llmcompressor/transformers/tracing/__init__.py index bb89a8ddf..19abc1a9d 100644 --- a/src/llmcompressor/transformers/tracing/__init__.py +++ b/src/llmcompressor/transformers/tracing/__init__.py @@ -1,17 +1,17 @@ from .glm.modeling_chatglm import ( - ChatGLMForConditionalGeneration as TracableChatGLMForConditionalGeneration, + ChatGLMForConditionalGeneration as TraceableChatGLMForConditionalGeneration, ) from .llava import ( - LlavaForConditionalGeneration as TracableLlavaForConditionalGeneration, + LlavaForConditionalGeneration as TraceableLlavaForConditionalGeneration, ) -from .mistral import MistralForCausalLM as TracableMistralForCausalLM +from .mistral import MistralForCausalLM as TraceableMistralForCausalLM from .mllama import ( - MllamaForConditionalGeneration as TracableMllamaForConditionalGeneration, + MllamaForConditionalGeneration as TraceableMllamaForConditionalGeneration, ) __all__ = [ - "TracableLlavaForConditionalGeneration", - "TracableMllamaForConditionalGeneration", - "TracableMistralForCausalLM", - "TracableChatGLMForConditionalGeneration", + "TraceableLlavaForConditionalGeneration", + "TraceableMllamaForConditionalGeneration", + "TraceableMistralForCausalLM", + "TraceableChatGLMForConditionalGeneration", ] diff --git a/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py b/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py index db40f3331..c59b6a0a1 100644 --- a/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py +++ b/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py @@ -886,7 +886,7 @@ def forward(self, input_ids): return embeddings -# TRACING: this function is untracable +# TRACING: this function is untraceable @wrap def is_empty(images_list: Optional[List[List[torch.Tensor]]]): if images_list is None or len(images_list) == 0: diff --git a/src/llmcompressor/transformers/tracing/llava.py b/src/llmcompressor/transformers/tracing/llava.py index 0068e7b7d..0f993a356 100644 --- a/src/llmcompressor/transformers/tracing/llava.py +++ b/src/llmcompressor/transformers/tracing/llava.py @@ -30,8 +30,8 @@ from transformers.models.mistral.configuration_mistral import MistralConfig from transformers.utils.fx import HFProxy -# TRACING: Reuse tracable subclass -from .mistral import MistralForCausalLM as TracableMistralForCausalLM +# TRACING: Reuse traceable subclass +from .mistral import MistralForCausalLM as TraceableMistralForCausalLM # TRACING: The shape of image_features is known and documented by @@ -82,9 +82,9 @@ def __init__(self, config: LlavaConfig): self.multi_modal_projector = LlavaMultiModalProjector(config) self.vocab_size = config.text_config.vocab_size - # TRACING: Must use TracableMistralForCausalLM which wraps an untracable function + # TRACING: Must use TraceableMistralForCausalLM which wraps an untraceable function if isinstance(config.text_config, MistralConfig): - self.language_model = TracableMistralForCausalLM(config.text_config) + self.language_model = TraceableMistralForCausalLM(config.text_config) else: self.language_model = AutoModelForCausalLM.from_config(config.text_config) diff --git a/src/llmcompressor/transformers/tracing/mllama.py b/src/llmcompressor/transformers/tracing/mllama.py index 512ba4227..3ff3aa94d 100644 --- a/src/llmcompressor/transformers/tracing/mllama.py +++ b/src/llmcompressor/transformers/tracing/mllama.py @@ -50,7 +50,7 @@ logger = logging.get_logger(__name__) -@wrap # NOT TRACABLE, wrap this function +@wrap # NOT TRACEABLE, wrap this function def _prepare_cross_attention_mask( cross_attention_mask: torch.Tensor, num_vision_tokens: int, From 71faee72825c336fa73a67602bda7ad8a7ccffdd Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 3 Jan 2025 17:38:00 -0500 Subject: [PATCH 261/285] mllama: remove unnecessary definitions Signed-off-by: Kyle Sayers --- .../transformers/tracing/mllama.py | 2506 +---------------- 1 file changed, 50 insertions(+), 2456 deletions(-) diff --git a/src/llmcompressor/transformers/tracing/mllama.py b/src/llmcompressor/transformers/tracing/mllama.py index 3ff3aa94d..395567afd 100644 --- a/src/llmcompressor/transformers/tracing/mllama.py +++ b/src/llmcompressor/transformers/tracing/mllama.py @@ -16,2318 +16,63 @@ # vllm-project: no copyright """PyTorch Mllama model.""" -import math from typing import List, Optional, Tuple, Union import torch -import torch.nn.functional as F import torch.utils.checkpoint -from torch import nn -from torch.fx import wrap -from transformers import PreTrainedModel -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, DynamicCache, StaticCache -from transformers.generation import GenerationMixin -from transformers.modeling_attn_mask_utils import AttentionMaskConverter -from transformers.modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS -from transformers.models.mllama.configuration_mllama import ( - MllamaConfig, - MllamaTextConfig, - MllamaVisionConfig, -) -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) - -logger = logging.get_logger(__name__) - - -@wrap # NOT TRACEABLE, wrap this function -def _prepare_cross_attention_mask( - cross_attention_mask: torch.Tensor, - num_vision_tokens: int, - dtype: str, -) -> Tuple[torch.Tensor, torch.Tensor]: - # reshape so it can be used by attn module - batch_size, text_total_length, *_ = cross_attention_mask.shape - cross_attention_mask = cross_attention_mask.repeat_interleave( - num_vision_tokens, dim=3 - ) - cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1) - cross_attention_mask = cross_attention_mask.unsqueeze(1) - - # invert the mask - inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype) - cross_attention_mask = inverted_cross_attn_mask.masked_fill( - inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min - ) - - # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's - # last dimension contains negative infinity values, otherwise it's 1 - negative_inf_value = torch.finfo(dtype).min - full_text_row_masked_out_mask = ( - (cross_attention_mask != negative_inf_value) - .any(dim=-1) - .type_as(cross_attention_mask)[..., None] - ) - cross_attention_mask *= full_text_row_masked_out_mask - - return cross_attention_mask, full_text_row_masked_out_mask - - -def _prepare_aspect_ratio_attention_mask( - aspect_ratio_mask: torch.Tensor, - num_patches: int, - target_length: int, - dtype: torch.dtype, -) -> torch.Tensor: - # Expand aspect ratio mask to target_length - batch_size, max_num_tiles = aspect_ratio_mask.shape - attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype) - attention_mask = attention_mask.repeat(1, 1, target_length, 1) - - # Mask padding patches - pad_patches = target_length - num_patches - attention_mask[:, :, -pad_patches:] = 0 - - # Invert the mask (0 -> 1, 1 -> 0) - attention_mask = 1 - attention_mask - - # Reshape to 2D and create 4D attention mask - # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length) - attention_mask = attention_mask.reshape( - batch_size, max_num_tiles * target_length, 1 - ) - attention_mask = ( - attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min - ) - attention_mask = attention_mask.unsqueeze(1) - - return attention_mask - - -class MllamaPrecomputedAspectRatioEmbedding(nn.Module): - def __init__(self, config: MllamaVisionConfig, is_gated: bool = True): - super().__init__() - self.max_num_tiles = config.max_num_tiles - self.hidden_size = config.hidden_size - self.max_aspect_ratio_id = config.max_aspect_ratio_id - self.is_gated = is_gated - - self.embedding = nn.Embedding( - self.max_aspect_ratio_id + 1, self.max_num_tiles * self.hidden_size - ) - if is_gated: - self.gate = nn.Parameter(torch.zeros(1)) - - def forward( - self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor - ) -> torch.Tensor: - embeddings = self.embedding(aspect_ratio_ids) - embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size) - - if self.is_gated: - embeddings = embeddings * self.gate.tanh() - - hidden_state = hidden_state + embeddings - return hidden_state - - -class MllamaPrecomputedPositionEmbedding(nn.Module): - def __init__(self, config: MllamaVisionConfig): - super().__init__() - self.max_num_tiles = config.max_num_tiles - self.max_aspect_ratio_id = config.max_aspect_ratio_id - self.num_patches = (config.image_size // config.patch_size) ** 2 + 1 - self.hidden_size = config.hidden_size - self.scale = config.hidden_size**-0.5 - - self.gate = nn.Parameter(torch.zeros(1)) - - # position embedding - position_embedding = torch.randn(self.num_patches, self.hidden_size) - self.embedding = nn.Parameter(self.scale * position_embedding) - - # tile position embedding - self.tile_embedding = nn.Embedding( - self.max_aspect_ratio_id + 1, - self.max_num_tiles * self.num_patches * self.hidden_size, - ) - - def forward( - self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor - ) -> torch.Tensor: - # position embeddings - gated_position_embedding = (1 - self.gate.tanh()) * self.embedding - hidden_state = hidden_state + gated_position_embedding.view( - 1, 1, self.num_patches, self.hidden_size - ) - - # precomputed tile position embeddings - tile_position_embedding = self.tile_embedding(aspect_ratio_ids) - batch_size = hidden_state.shape[0] - tile_position_embedding = tile_position_embedding.reshape( - batch_size, self.max_num_tiles, self.num_patches, self.hidden_size - ) - gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding - hidden_state = hidden_state + gated_tile_position_embedding - - return hidden_state - - -# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision -class MllamaVisionMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.activation_fn = ACT2FN[config.hidden_act] - self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) - self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - -class MllamaVisionAttention(nn.Module): - def __init__(self, config: MllamaVisionConfig): - super().__init__() - - self.embed_dim = config.hidden_size - self.num_heads = config.attention_heads - self.head_dim = config.hidden_size // config.attention_heads - - self.q_proj = nn.Linear( - self.embed_dim, self.num_heads * self.head_dim, bias=False - ) - self.k_proj = nn.Linear( - self.embed_dim, self.num_heads * self.head_dim, bias=False - ) - self.v_proj = nn.Linear( - self.embed_dim, self.num_heads * self.head_dim, bias=False - ) - self.o_proj = nn.Linear( - self.num_heads * self.head_dim, self.embed_dim, bias=False - ) - - def forward( - self, - hidden_state: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = None, - ) -> torch.Tensor: - query = self.q_proj(hidden_state) - key = self.k_proj(hidden_state) - value = self.v_proj(hidden_state) - - batch_size, q_seq_len, _ = query.shape - _, kv_seq_len, _ = key.shape - - query = query.view( - batch_size, q_seq_len, self.num_heads, self.head_dim - ).transpose(1, 2) - key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose( - 1, 2 - ) - value = value.view( - batch_size, kv_seq_len, self.num_heads, self.head_dim - ).transpose(1, 2) - - attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt( - self.head_dim - ) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(query.dtype) - attn_output = torch.matmul(attn_weights, value) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_seq_len, -1) - - output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return output, attn_weights - - -class MllamaVisionSdpaAttention(MllamaVisionAttention): - # Adapted from MllamaVisionAttention - def forward( - self, - hidden_state: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = None, - ) -> torch.Tensor: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - if output_attentions: - logger.warning_once( - "MllamaModel is using MllamaVisionSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_state=hidden_state, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - - query = self.q_proj(hidden_state) - key = self.k_proj(hidden_state) - value = self.v_proj(hidden_state) - - batch_size, q_seq_len, _ = query.shape - _, kv_seq_len, _ = key.shape - - query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim) - key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) - value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) - - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - attn_output = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_seq_len, -1) - - output = self.o_proj(attn_output) - - return output, None - - -MLLAMA_VISION_ATTENTION_CLASSES = { - "eager": MllamaVisionAttention, - "sdpa": MllamaVisionSdpaAttention, -} - - -class MllamaVisionEncoderLayer(nn.Module): - def __init__(self, config: MllamaVisionConfig, is_gated: bool = False): - super().__init__() - - self.hidden_size = config.hidden_size - self.num_attention_heads = config.attention_heads - self.is_gated = is_gated - self.intermediate_size = config.intermediate_size - - self.self_attn = MLLAMA_VISION_ATTENTION_CLASSES[config._attn_implementation]( - config - ) - self.mlp = MllamaVisionMLP(config) - - self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps) - self.post_attention_layernorm = nn.LayerNorm( - self.hidden_size, eps=config.norm_eps - ) - - if is_gated: - self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4) - self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4) - - def forward( - self, - hidden_state: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = None, - ): - # Self Attention - residual = hidden_state - hidden_state = self.input_layernorm(hidden_state) - hidden_state, attn_weights = self.self_attn( - hidden_state, attention_mask=attention_mask - ) - if self.is_gated: - hidden_state = self.gate_attn.tanh() * hidden_state - hidden_state = residual + hidden_state - - # Feed forward - residual = hidden_state - hidden_state = self.post_attention_layernorm(hidden_state) - hidden_state = self.mlp(hidden_state) - if self.is_gated: - hidden_state = self.gate_ffn.tanh() * hidden_state - hidden_state = residual + hidden_state - - outputs = (hidden_state,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -class MllamaVisionEncoder(nn.Module): - """ - Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a - [`MllamaEncoderLayer`]. - - Args: - config: MllamaConfig - """ - - def __init__(self, config: MllamaVisionConfig, num_layers=32, is_gated=False): - super().__init__() - self.config = config - self.layers = nn.ModuleList( - [MllamaVisionEncoderLayer(config, is_gated) for _ in range(num_layers)] - ) - self.gradient_checkpointing = False - self.config = config - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutput]: - r""" - Args: - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - for encoder_layer in self.layers: - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_state=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - hidden_states = layer_outputs[0] - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [hidden_states, encoder_states, all_attentions] - if v is not None - ) - return BaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=encoder_states, - attentions=all_attentions, - ) - - -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText -class MllamaTextRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - MllamaTextRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -class MllamaTextCrossAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - config: Optional[MllamaTextConfig] = None, - layer_idx: Optional[int] = None, - ): - super().__init__() - self.config = config - self.num_heads = self.config.num_attention_heads - self.num_key_value_heads = self.config.num_key_value_heads - self.dropout = config.dropout - self.hidden_size = config.hidden_size - self.head_dim = config.hidden_size // self.num_heads - self.layer_idx = layer_idx - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=False - ) - self.k_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False - ) - self.v_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False - ) - self.o_proj = nn.Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=False - ) - - self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - cross_attention_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - query_states = self.q_norm(query_states) - - if cross_attention_states is not None: - key_states = self.k_proj(cross_attention_states) - value_states = self.v_proj(cross_attention_states) - key_states = key_states.view( - bsz, -1, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, -1, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - key_states = self.k_norm(key_states) - if past_key_value is not None: - # if we have a new image + new tokens, we only computed key_states on that new image - # we still update the cross key states, past_image, new_image. And use it! - key_states, value_states = past_key_value.update( - key_states, - value_states, - self.layer_idx, - {"cache_position": cache_position}, - ) - elif cache_position[0] != 0: - key_states, value_states = ( - past_key_value.key_cache[self.layer_idx], - past_key_value.value_cache[self.layer_idx], - ) - else: - raise ValueError( - "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" - ) - - attn_weights = torch.matmul( - query_states, key_states.transpose(2, 3) - ) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(query_states.dtype) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class MllamaTextCrossSdpaAttention(MllamaTextCrossAttention): - """ - Mllama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `MllamaTextCrossAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from MllamaTextCrossAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - cross_attention_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "MllamaModel is using MllamaTextCrossSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - cross_attention_states=cross_attention_states, - attention_mask=attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - query_states = self.q_norm(query_states) - - if cross_attention_states is not None: - key_states = self.k_proj(cross_attention_states) - value_states = self.v_proj(cross_attention_states) - key_states = key_states.view( - bsz, -1, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, -1, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - - if past_key_value is not None: - # if we have a new image + new tokens, we only computed key_states on that new image - # we still update the cross key states, past_image, new_image. And use it! - key_states, value_states = past_key_value.update( - key_states, - value_states, - self.layer_idx, - {"cache_position": cache_position}, - ) - elif cache_position[0] != 0: - key_states, value_states = ( - past_key_value.key_cache[self.layer_idx], - past_key_value.value_cache[self.layer_idx], - ) - else: - raise ValueError( - "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" - ) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - key_states = self.k_norm(key_states) - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if attention_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class MllamaTextSelfAttention(nn.Module): - def __init__(self, config: MllamaTextConfig, layer_idx: int): - super().__init__() - self.config = config - self.num_heads = config.num_attention_heads - self.dropout = config.dropout - self.hidden_size = config.hidden_size - self.num_key_value_heads = config.num_key_value_heads - self.head_dim = config.hidden_size // self.num_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.rope_theta = config.rope_theta - self.layer_idx = layer_idx - - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=False - ) - self.k_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False - ) - self.v_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False - ) - self.o_proj = nn.Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=False - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - position_embeddings: torch.Tensor, - output_attentions: bool = False, - use_cache: bool = False, - past_key_value=None, - cache_position=None, - **kwargs, - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - key_states = key_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin - ) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul( - query_states, key_states.transpose(2, 3) - ) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(query_states.dtype) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) - attn_output = torch.matmul(attn_weights, value_states) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class MllamaTextSelfSdpaAttention(MllamaTextSelfAttention): - # Adapted from MllamaTextSelfAttention - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - position_embeddings: torch.Tensor, - output_attentions: bool = False, - use_cache: bool = False, - past_key_value=None, - cache_position=None, - **kwargs, - ): - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "MllamaModel is using MllamaTextSelfSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_embeddings=position_embeddings, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - key_states = key_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin - ) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value - - -MLLAMA_TEXT_CROSS_ATTENTION_CLASSES = { - "eager": MllamaTextCrossAttention, - "sdpa": MllamaTextCrossSdpaAttention, -} -MLLAMA_TEXT_ATTENTION_CLASSES = { - "eager": MllamaTextSelfAttention, - "sdpa": MllamaTextSelfSdpaAttention, -} - - -# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText -class MllamaTextMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - # Ignore copy - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -# Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer -class MllamaSelfAttentionDecoderLayer(nn.Module): - def __init__(self, config: MllamaTextConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - self.self_attn = MLLAMA_TEXT_ATTENTION_CLASSES[config._attn_implementation]( - config=config, layer_idx=layer_idx - ) - - self.mlp = MllamaTextMLP(config) - self.input_layernorm = MllamaTextRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - self.post_attention_layernorm = MllamaTextRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - - self.layer_idx = layer_idx - - def forward( - self, - hidden_states: torch.Tensor, - cross_attention_states: Optional[torch.Tensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[ - Tuple[torch.Tensor, torch.Tensor] - ] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[ - Tuple[torch.Tensor, torch.Tensor] - ] = None, # will become mandatory in v4.45 - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -class MllamaCrossAttentionDecoderLayer(torch.nn.Module): - """Cross-attention transformer block with tanh-gated attention and feedforward.""" - - def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None: - super().__init__() - self.layer_idx = layer_idx - self.cross_attn = MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[ - config._attn_implementation - ](config, layer_idx=layer_idx) - - self.input_layernorm = MllamaTextRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1)) - - self.mlp = MllamaTextMLP(config) - self.post_attention_layernorm = MllamaTextRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1)) - - def forward( - self, - hidden_states: torch.Tensor, - cross_attention_states: torch.Tensor, - cross_attention_mask: torch.Tensor, - attention_mask: torch.Tensor, - full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor], - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor]: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - hidden_states, attn_weights, past_key_value = self.cross_attn( - hidden_states=hidden_states, - attention_mask=cross_attention_mask, - cross_attention_states=cross_attention_states, - past_key_value=past_key_value, - output_attentions=output_attentions, - cache_position=cache_position, - ) - hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - if full_text_row_masked_out_mask is not None: - hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states # type: ignore - hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - if use_cache: - outputs += (past_key_value,) - - return outputs - - -class MllamaRotaryEmbedding(nn.Module): - def __init__(self, config: MllamaTextConfig, device=None): - super().__init__() - self.rope_type = config.rope_scaling["rope_type"] - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer( - "inv_freq", inv_freq, persistent=False - ) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if ( - seq_len < self.original_max_seq_len - and self.max_seq_len_cached > self.original_max_seq_len - ): # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = ( - self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - ) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = ( - device_type - if isinstance(device_type, str) and device_type != "mps" - else "cpu" - ) - with torch.autocast(device_type=device_type, enabled=False): - freqs = ( - inv_freq_expanded.float() @ position_ids_expanded.float() - ).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class MllamaPreTrainedModel(PreTrainedModel): - config_class = MllamaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = [ - "MllamaVisionEncoderLayer", - "MllamaCrossAttentionDecoderLayer", - "MllamaSelfAttentionDecoderLayer", - ] - _supports_cache_class = True - _supports_static_cache = ( - False # static cache cannot have different shapes for each layer - ) - _supports_sdpa = True - _supports_quantized_cache = True - - def _init_weights(self, module): - std = self.config.get_text_config().initializer_range - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.Parameter): - module.data.normal_(mean=0.0, std=std) - elif isinstance(module, MllamaVisionModel): - nn.init.normal_(module.class_embedding.data, std=std) - elif isinstance(module, MllamaPrecomputedPositionEmbedding): - nn.init.normal_(module.embedding.data, std=std) - elif isinstance(module, MllamaVisionEncoderLayer) and module.is_gated: - nn.init.normal_(module.gate_attn.data, std=std) - nn.init.normal_(module.gate_ffn.data, std=std) - - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 - ) - using_static_cache = isinstance(past_key_values, StaticCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not using_static_cache - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype, device = input_tensor.dtype, input_tensor.device - sequence_length = input_tensor.shape[1] - if using_static_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type == "cuda" - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended( - causal_mask, min_dtype - ) - - return causal_mask - - @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), - fill_value=min_dtype, - dtype=dtype, - device=device, - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange( - target_length, device=device - ) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = ( - causal_mask.clone() - ) # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = ( - causal_mask[:, :, :, :mask_length] - + attention_mask[:, None, None, :] - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[ - :, :, :, :mask_length - ].masked_fill(padding_mask, min_dtype) - - return causal_mask - - -MLLAMA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`MllamaConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -MLLAMA_VISION_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`torch.FloatTensor` of shape `(batch_size, max_num_images, max_num_tiles, channels, image_size, image_size)): - The tensors corresponding to the input images. Pixel values can be obtained using - [`AutoImageProcessor`]. See [`MllamaImageProcessor.__call__`] for details ([]`MllamaProcessor`] uses - [`MllamaImageProcessor`] for processing images). - aspect_ratio_mask (`torch.Tensor` of shape `(batch_size, max_num_images, max_num_tiles)`, *optional*): - Mask to avoid performing attention on padding tiles. Mask values selected in `[0, 1]`: - - - 1 for tiles that are **not masked**, - - 0 for tiles that are **masked**. - aspect_ratio_ids (`torch.Tensor` of shape `(batch_size, max_num_images)`, *optional*): - Aspect ratio ids used to select the appropriate precomputed tile embeddings based on the aspect ratio of each input image. - These ids correspond to indices in the model's list of supported aspect ratios, offset by 1. - - For example, if the model supports aspect ratios [[1, 1], [1, 2], [2, 1]]: - - An image with aspect ratio [1, 1] would have ID 1 - - An image with aspect ratio [1, 2] would have ID 2 - - An image with aspect ratio [2, 1] would have ID 3 - - The id 0 is reserved for padding (i.e., no image). - - If an image has aspect ratio [1, 2], that means it was split into 2 tiles horizontally, and its `aspect_ratio_id` would be 2. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -MLLAMA_TEXT_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - cross_attention_mask (`torch.Tensor` of shape `(batch_size, seq_length, max_num_images, max_num_tiles)`, *optional*): - Cross-attention mask to control the interaction between text tokens and image tiles. - This 4D tensor defines which image tiles each text token should attend to. - - For each text token (in seq_length): - - 1 indicates the token **should attend** to the corresponding image tile - - 0 indicates the token **should not attend** to the corresponding image tile - cross_attention_states (`torch.FloatTensor`, *optional*): - Output of the vision model, used for cross-attention. This tensor contains the processed image features that - the language model will attend to. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance, see our - [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. - - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. -""" - - -MLLAMA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - pixel_values (`torch.FloatTensor` of shape `(batch_size, max_num_images, max_num_tiles, channels, image_size, image_size)): - The tensors corresponding to the input images. Pixel values can be obtained using - [`AutoImageProcessor`]. See [`MllamaImageProcessor.__call__`] for details ([]`MllamaProcessor`] uses - [`MllamaImageProcessor`] for processing images). - aspect_ratio_mask (`torch.Tensor` of shape `(batch_size, max_num_images, max_num_tiles)`, *optional*): - Mask to avoid performing attention on padding tiles. Mask values selected in `[0, 1]`: - - - 1 for tiles that are **not masked**, - - 0 for tiles that are **masked**. - aspect_ratio_ids (`torch.Tensor` of shape `(batch_size, max_num_images)`, *optional*): - Aspect ratio ids used to select the appropriate precomputed tile embeddings based on the aspect ratio of each input image. - These ids correspond to indices in the model's list of supported aspect ratios, offset by 1. - - For example, if the model supports aspect ratios [[1, 1], [1, 2], [2, 1]]: - - An image with aspect ratio [1, 1] would have ID 1 - - An image with aspect ratio [1, 2] would have ID 2 - - An image with aspect ratio [2, 1] would have ID 3 - - The id 0 is reserved for padding (i.e., no image). - - If an image has aspect ratio [1, 2], that means it was split into 2 tiles horizontally, and its `aspect_ratio_id` would be 2. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - cross_attention_mask (`torch.Tensor` of shape `(batch_size, seq_length, max_num_images, max_num_tiles)`, *optional*): - Cross-attention mask to control the interaction between text tokens and image tiles. - This 4D tensor defines which image tiles each text token should attend to. - - For each text token (in seq_length): - - 1 indicates the token **should attend** to the corresponding image tile - - 0 indicates the token **should not attend** to the corresponding image tile - cross_attention_states (`torch.FloatTensor`, *optional*): - Output of the vision model, used for cross-attention. This tensor contains the processed image features that - the language model will attend to. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance, see our - [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. - - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. -""" - - -@add_start_docstrings( - """The Mllama Vision Model which consists of two vision encoders.""", - MLLAMA_START_DOCSTRING, -) -class MllamaVisionModel(MllamaPreTrainedModel): - config_class = MllamaVisionConfig - base_model_prefix = "vision_model" - - def __init__(self, config: MllamaVisionConfig): - super().__init__(config) - self.image_size = config.image_size - self.patch_size = config.patch_size - self.max_num_tiles = config.max_num_tiles - self.hidden_size = config.hidden_size - self.num_channels = config.num_channels - self.intermediate_layers_indices = config.intermediate_layers_indices - - self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 - self.scale = config.hidden_size**-0.5 - - self.patch_embedding = nn.Conv2d( - in_channels=config.num_channels, - out_channels=self.hidden_size, - kernel_size=self.patch_size, - stride=self.patch_size, - padding="valid", - bias=False, - ) - - self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size)) - self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(config) - - self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( - config, is_gated=True - ) - self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( - config, is_gated=True - ) - - # layer norms - self.layernorm_pre = nn.LayerNorm(self.hidden_size) - self.layernorm_post = nn.LayerNorm(self.hidden_size) - - # encoders - self.transformer = MllamaVisionEncoder( - config, config.num_hidden_layers, is_gated=False - ) - self.global_transformer = MllamaVisionEncoder( - config, config.num_global_layers, is_gated=True - ) - self.post_init() - - def get_input_embeddings(self): - """ - This function is used to fetch the first embedding layer to activate grads on inputs. - """ - return self.patch_embedding - - def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: - batch_size, _, hidden_size = hidden_state.shape - class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) - hidden_state = torch.cat([class_embedding, hidden_state], dim=1) - return hidden_state - - @add_start_docstrings_to_model_forward(MLLAMA_VISION_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=BaseModelOutput, config_class="MllamaVisionConfig" - ) - def forward( - self, - pixel_values: torch.Tensor, - aspect_ratio_ids: torch.Tensor, - aspect_ratio_mask: torch.Tensor, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: - r""" - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, MllamaVisionModel - - >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision" - >>> model = MllamaVisionModel.from_pretrained(checkpoint) - >>> processor = AutoProcessor.from_pretrained(checkpoint) - - >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - >>> inputs = processor(images=image, return_tensors="pt") - - >>> output = model(**inputs) - - >>> print(output.last_hidden_state.shape) - torch.Size([1, 1, 4, 1025, 7680]) - ``` - """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - batch_size, num_concurrent_media, num_tiles, num_channels, height, width = ( - pixel_values.shape - ) - - pixel_values = pixel_values.reshape( - batch_size * num_concurrent_media * num_tiles, num_channels, height, width - ) - aspect_ratio_ids = aspect_ratio_ids.reshape( - batch_size * num_concurrent_media, -1 - ) - - # Patch embedding - patch_embeds = self.patch_embedding(pixel_values.to(self.dtype).to(self.device)) - hidden_state = patch_embeds.flatten(2).transpose(1, 2) - - # Tile embeddings - _, num_patches, dim = hidden_state.shape - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media, num_tiles, -1, dim - ) - hidden_state = self.pre_tile_positional_embedding( - hidden_state, aspect_ratio_ids - ) - - # Add cls token - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media * num_tiles, num_patches, dim - ) - hidden_state = self.apply_class_embedding(hidden_state) - num_patches += 1 - - # Position embeddings - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media, num_tiles, num_patches, dim - ) - hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids) - - hidden_state = self.layernorm_pre(hidden_state) - - # Compute the number of tokens to pad - num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 - # Compute padding tuple for pad function - padding = ( - 0, - 0, - 0, - num_padding_patches, - ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) - # Pad the tensor - hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) - slice_index = -num_padding_patches if num_padding_patches > 0 else None - - # Prepare attention mask - attention_mask = aspect_ratio_mask.reshape( - batch_size * num_concurrent_media, -1 - ) - attention_mask = _prepare_aspect_ratio_attention_mask( - aspect_ratio_mask=attention_mask, - num_patches=self.num_patches, - target_length=hidden_state.shape[2], - dtype=self.dtype, - ) - - # Apply encoder - hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim) - output = self.transformer( - hidden_state, - attention_mask=attention_mask, - output_hidden_states=True, - output_attentions=output_attentions, - ) - hidden_state = output[0] - - hidden_state = self.layernorm_post(hidden_state) - - # Apply global encoder - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media, - num_tiles, - num_patches + num_padding_patches, - dim, - ) - hidden_state = self.post_tile_positional_embedding( - hidden_state, aspect_ratio_ids - ) - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media, - num_tiles * (num_patches + num_padding_patches), - dim, - ) - global_output = self.global_transformer( - hidden_state, - attention_mask=attention_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - ) - hidden_state = global_output[0] - - # Remove padding form hidden state - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media, - num_tiles, - num_patches + num_padding_patches, - dim, - ) - hidden_state = hidden_state[:, :, :slice_index] - hidden_state = hidden_state.reshape( - batch_size, num_concurrent_media, num_tiles, num_patches, dim - ) - - # Collect intermediate layer outputs from encoder output - all_intermediate_hidden_states = [ - output[1][i] for i in self.intermediate_layers_indices - ] - intermediate_hidden_states = torch.stack(all_intermediate_hidden_states, dim=-1) - - # Remove padding from intermediate hidden states - intermediate_hidden_states = intermediate_hidden_states.reshape( - batch_size * num_concurrent_media, - num_tiles, - num_patches + num_padding_patches, - -1, - ) - intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index] - intermediate_hidden_states = intermediate_hidden_states.reshape( - batch_size, num_concurrent_media, num_tiles, num_patches, -1 - ) - - # Concatenate final hidden state and intermediate hidden states - hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1) - - if output_hidden_states: - hidden_states = tuple(all_intermediate_hidden_states) + tuple( - global_output[1] - ) - else: - hidden_states = None - - if output_attentions: - # global transformer in contrast to `self.transformer` doesn't always return hidden states so we might go index out-of-range - global_attn = ( - tuple(global_output[2]) - if output_hidden_states - else tuple(global_output[1]) - ) - attentions = tuple(output[2]) + global_attn - else: - attentions = None - - if not return_dict: - return tuple( - v for v in [hidden_state, hidden_states, attentions] if v is not None - ) - - return BaseModelOutput( - last_hidden_state=hidden_state, - hidden_states=hidden_states, - attentions=attentions, - ) - - -@add_start_docstrings( - """The Mllama Text Model which consists of transformer with self and cross attention layers.""", - MLLAMA_START_DOCSTRING, +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.utils import ( + add_start_docstrings, + logging, ) -class MllamaTextModel(MllamaPreTrainedModel): - config_class = MllamaTextConfig - base_model_prefix = "language_model.model" - - def __init__(self, config: MllamaTextConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding( - config.vocab_size + 8, config.hidden_size, self.padding_idx - ) - self.cross_attention_layers = config.cross_attention_layers - - layers = [] - for layer_idx in range(config.num_hidden_layers): - if layer_idx in self.cross_attention_layers: - layers.append(MllamaCrossAttentionDecoderLayer(config, layer_idx)) - else: - layers.append(MllamaSelfAttentionDecoderLayer(config, layer_idx)) - - self.layers = nn.ModuleList(layers) - self.norm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = MllamaRotaryEmbedding(config=config) - self.gradient_checkpointing = False - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @add_start_docstrings_to_model_forward(MLLAMA_TEXT_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=BaseModelOutputWithPast, config_class="MllamaTextConfig" - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - cross_attention_states: Optional[torch.FloatTensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[ - Tuple[torch.Tensor, torch.Tensor] - ] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - """ - - Returns: - - Example: - - ```python - >>> from transformers import AutoProcessor, MllamaTextModel - - >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision" - >>> model = MllamaTextModel.from_pretrained(checkpoint) - >>> processor = AutoProcessor.from_pretrained(checkpoint) - - >>> text = "<|image|>If I had to write a haiku for this one" - >>> inputs = processor(text=text, return_tensors="pt") - - >>> output = model(**inputs) - - >>> print(output.last_hidden_state.shape) - torch.Size([1, 13, 4096]) - ``` - """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You must specify exactly one of input_ids or inputs_embeds" - ) - - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - hidden_states = inputs_embeds - - if use_cache and past_key_values is None: - past_key_values = DynamicCache() - - if cache_position is None: - past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 - ) - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device, - ) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = self._update_causal_mask( - attention_mask, - inputs_embeds, - cache_position, - past_key_values, - output_attentions, - ) - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - # For text-only path we should skip cross attention layers. - # Let's check if the layer is cross attention layer and if we have cross attention states - # or cached cross attention states. - is_cross_attention_layer = idx in self.cross_attention_layers - is_cross_attention_cache_empty = past_key_values is None or ( - past_key_values is not None and past_key_values.get_seq_length(idx) == 0 - ) - - if ( - is_cross_attention_layer - and cross_attention_states is None - and is_cross_attention_cache_empty - ): - continue - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - cross_attention_states, - cross_attention_mask, - causal_mask, - full_text_row_masked_out_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - attention_mask=causal_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -@add_start_docstrings( - """The Mllama Text Model with a language modeling head on top.""", +# TRACING: imports +from torch.fx import wrap +from transformers.models.mllama.modeling_mllama import ( MLLAMA_START_DOCSTRING, + MllamaForConditionalGeneration, ) -class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin): - config_class = MllamaTextConfig - _supports_static_cache = True # only the LLM without cross attn can do compile - base_model_prefix = "language_model" - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config.get_text_config()) - self.text_config = config.get_text_config() - self.vocab_size = self.text_config.vocab_size - self.model = MllamaTextModel._from_config(self.text_config) - self.lm_head = nn.Linear( - self.text_config.hidden_size, self.vocab_size, bias=False - ) - - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings +logger = logging.get_logger(__name__) - def set_decoder(self, decoder): - self.model = decoder - def get_decoder(self): - return self.model +# TRACING: This function is not traceable +@wrap +def _prepare_cross_attention_mask( + cross_attention_mask: torch.Tensor, + num_vision_tokens: int, + dtype: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + # reshape so it can be used by attn module + batch_size, text_total_length, *_ = cross_attention_mask.shape + cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3) + cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1) + cross_attention_mask = cross_attention_mask.unsqueeze(1) - @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig" + # invert the mask + inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min ) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - cross_attention_states: Optional[torch.LongTensor] = None, - cross_attention_mask: Optional[torch.LongTensor] = None, - full_text_row_masked_out_mask: Optional[ - Tuple[torch.Tensor, torch.Tensor] - ] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - **loss_kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, MllamaForCausalLM - - >>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision") - >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision") - - >>> prompt = "If I had to write a haiku, it would be:" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6) - >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - >>> print(result) - If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful. - I love the idea of snowflakes gently falling, each one - ``` - """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - cross_attention_states=cross_attention_states, - attention_mask=attention_mask, - position_ids=position_ids, - cross_attention_mask=cross_attention_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's + # last dimension contains negative infinity values, otherwise it's 1 + negative_inf_value = torch.finfo(dtype).min + full_text_row_masked_out_mask = ( + (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] + ) + cross_attention_mask *= full_text_row_masked_out_mask - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) + return cross_attention_mask, full_text_row_masked_out_mask +# TRACING: needs to use wrapped _prepare_cross_attention_mask @add_start_docstrings( """The Mllama model which consists of a vision encoder and a language model.""", MLLAMA_START_DOCSTRING, ) -class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): - _supports_quantized_cache = ( - False # quant cache not supported in encoder-decoder setting - ) - - def __init__(self, config: MllamaConfig): - super().__init__(config) - self.vocab_size = config.text_config.vocab_size - self.hidden_size = config.text_config.hidden_size - self.max_num_tiles = config.vision_config.max_num_tiles - self.vision_output_dim = config.vision_config.vision_output_dim - self.pad_token_id = ( - self.config.pad_token_id if self.config.pad_token_id is not None else -1 - ) - - self.vision_model = MllamaVisionModel._from_config(config.vision_config) - self.language_model = MllamaForCausalLM._from_config(config.text_config) - self.multi_modal_projector = nn.Linear( - config.vision_config.vision_output_dim, - config.text_config.hidden_size, - bias=True, - ) - self.post_init() - - def get_input_embeddings(self): - return self.language_model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.language_model.set_input_embeddings(value) - - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - - def tie_weights(self): - return self.language_model.tie_weights() - - @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class="MllamaConfig" - ) +class MllamaForConditionalGeneration(MllamaForConditionalGeneration): def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -2348,66 +93,14 @@ def forward( cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, MllamaForConditionalGeneration - - >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision" - >>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint) - >>> processor = AutoProcessor.from_pretrained(checkpoint) - - >>> prompt = "<|image|>If I had to write a haiku for this one" - >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(text=prompt, images=image, return_tensors="pt") - - >>> # Generate - >>> output = model.generate(**inputs, max_new_tokens=15) - - >>> prompt_len = inputs.input_ids.shape[-1] - >>> generated_ids = output[:, prompt_len:] - >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) - >>> print(generated_text) - [', it would be:.\\nA stop sign in Chinatown.\\n'] - ``` - """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You must specify exactly one of input_ids or inputs_embeds" - ) + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if pixel_values is not None and inputs_embeds is not None: raise ValueError( @@ -2415,15 +108,11 @@ def forward( ) if pixel_values is not None and cross_attention_states is not None: - raise ValueError( - "`pixel_values` and `cross_attention_states` cannot be provided simultaneously" - ) + raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously") if pixel_values is not None: if aspect_ratio_ids is None: - raise ValueError( - "`aspect_ratio_ids` must be provided if `pixel_values` is provided" - ) + raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided") # get vision tokens from vision model vision_outputs = self.vision_model( pixel_values=pixel_values, @@ -2434,26 +123,23 @@ def forward( return_dict=return_dict, ) cross_attention_states = vision_outputs[0] - cross_attention_states = self.multi_modal_projector( - cross_attention_states - ).reshape(-1, cross_attention_states.shape[-2], self.hidden_size) + cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape( + -1, cross_attention_states.shape[-2], self.hidden_size + ) if cross_attention_mask is not None: - cross_attention_mask, full_text_row_masked_out_mask = ( - _prepare_cross_attention_mask( - cross_attention_mask, - num_vision_tokens=self.vision_model.num_patches, - dtype=self.dtype, - ) + # TRACING: use wrapped function + cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask( + cross_attention_mask, + num_vision_tokens=self.vision_model.num_patches, + dtype=self.dtype, ) else: full_text_row_masked_out_mask = None if cross_attention_mask is not None and cache_position is not None: cross_attention_mask = cross_attention_mask[:, :, cache_position] - full_text_row_masked_out_mask = full_text_row_masked_out_mask[ - :, :, cache_position - ] + full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position] outputs = self.language_model( input_ids=input_ids, @@ -2474,95 +160,3 @@ def forward( ) return outputs - - def prepare_inputs_for_generation( - self, - input_ids=None, - inputs_embeds=None, - attention_mask=None, - position_ids=None, - pixel_values=None, - aspect_ratio_ids=None, - aspect_ratio_mask=None, - cross_attention_mask=None, - past_key_values=None, - use_cache=False, - cache_position=None, - num_logits_to_keep=None, - **kwargs, - ): - # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - if past_key_values is not None: - if inputs_embeds is not None: # Exception 1 - input_ids = input_ids[:, -cache_position.shape[0] :] - elif ( - input_ids.shape[1] != cache_position.shape[0] - ): # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - - # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. - position_ids = position_ids.clone(memory_format=torch.contiguous_format) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} - else: - # The clone here is for the same reason as for `position_ids`. - model_inputs = { - "input_ids": input_ids.clone(memory_format=torch.contiguous_format), - "inputs_embeds": None, - } - - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep - - model_inputs.update( - { - "position_ids": position_ids, - "cache_position": cache_position, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - "cross_attention_mask": cross_attention_mask, - } - ) - - # If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios - # to compute image hidden states, otherwise they are cached within each cross attn layer - if cache_position[0] == 0: - model_inputs["pixel_values"] = pixel_values - model_inputs["aspect_ratio_ids"] = aspect_ratio_ids - model_inputs["aspect_ratio_mask"] = aspect_ratio_mask - - return model_inputs - - def _update_model_kwargs_for_generation( - self, outputs, model_kwargs, is_encoder_decoder, **kwargs - ): - cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None) - model_kwargs = super()._update_model_kwargs_for_generation( - outputs=outputs, - model_kwargs=model_kwargs, - is_encoder_decoder=is_encoder_decoder, - **kwargs, - ) - - # add cross-attn mask for new token - if cross_attention_mask_prev is not None: - model_kwargs["cross_attention_mask"] = torch.cat( - [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], - dim=1, - ) - return model_kwargs From 557467bceb321e1f6cbb7bccf3436d6618246fe1 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 3 Jan 2025 17:38:14 -0500 Subject: [PATCH 262/285] add keyboard interrupts to list of unfixable errors Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/quantization/gptq/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 550b9849b..0595ad7db 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -217,7 +217,11 @@ def on_initialize(self, state: "State", **kwargs) -> bool: # infer pipeline model_name = state.model.__class__.__name__ input_names = state.data.calib.dataset.column_names - unfixable_errors = (torch.OutOfMemoryError, torch._C._LinAlgError) + unfixable_errors = ( + torch.OutOfMemoryError, + torch._C._LinAlgError, + KeyboardInterrupt, + ) try: run_sequential( state.model, From e158b9bdcbd32e5982452d716268d0f6da955496 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 3 Jan 2025 17:54:44 -0500 Subject: [PATCH 263/285] mistral: remove unnecessary definitions Signed-off-by: Kyle Sayers --- .../transformers/tracing/mistral.py | 1342 +---------------- 1 file changed, 31 insertions(+), 1311 deletions(-) diff --git a/src/llmcompressor/transformers/tracing/mistral.py b/src/llmcompressor/transformers/tracing/mistral.py index 7a63099c3..e6d96d912 100644 --- a/src/llmcompressor/transformers/tracing/mistral.py +++ b/src/llmcompressor/transformers/tracing/mistral.py @@ -21,53 +21,31 @@ # vllm-project: no copyright """PyTorch Mistral model.""" -import math -from typing import List, Optional, Tuple, Union - import torch -import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss -from transformers.activations import ACT2FN -from transformers.cache_utils import ( - Cache, - DynamicCache, - SlidingWindowCache, - StaticCache, -) -from transformers.generation import GenerationMixin + +from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - QuestionAnsweringModelOutput, - SequenceClassifierOutputWithPast, - TokenClassifierOutput, -) -from transformers.modeling_utils import PreTrainedModel from transformers.models.mistral.configuration_mistral import MistralConfig from transformers.utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, - replace_return_docstrings, ) -if is_flash_attn_2_available(): - from transformers.modeling_flash_attention_utils import _flash_attention_forward +# TRACING: imports +from torch.fx import wrap +from transformers.models.mistral.modeling_mistral import ( + MistralPreTrainedModel, + MistralModel, + MistralForCausalLM, + MistralForSequenceClassification, + MistralForTokenClassification, + MistralForQuestionAnswering, +) logger = logging.get_logger(__name__) -_CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1" -_CONFIG_FOR_DOC = "MistralConfig" - - -from torch.fx import wrap - +# TRACING: This function is untracable @wrap def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, @@ -80,30 +58,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position( config: MistralConfig, past_key_values: Cache, ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - config (`MistralConfig`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. causal_mask = attention_mask @@ -148,857 +102,8 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral -class MistralRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - MistralRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -class MistralRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / ( - self.base - ** ( - torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) - / self.dim - ) - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - @torch.no_grad() - # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward - # TODO(joao): add me back asap :) - def forward(self, x, position_ids): - # x: [bs, num_attention_heads, seq_len, head_size] - inv_freq_expanded = ( - self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - ) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = ( - device_type - if isinstance(device_type, str) and device_type != "mps" - else "cpu" - ) - with torch.autocast(device_type=device_type, enabled=False): - freqs = ( - inv_freq_expanded.float() @ position_ids_expanded.float() - ).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class MistralMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_state): - return self.down_proj( - self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state) - ) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class MistralAttention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ - - def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=False - ) - self.k_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False - ) - self.v_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False - ) - self.o_proj = nn.Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=False - ) - - self.rotary_emb = MistralRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin - ) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul( - query_states, key_states.transpose(2, 3) - ) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(query_states.dtype) - attn_weights = nn.functional.dropout( - attn_weights, p=self.attention_dropout, training=self.training - ) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class MistralFlashAttention2(MistralAttention): - """ - Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ): - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += cache_position[0] - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin - ) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self.config, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) - - attn_output = attn_output.reshape( - bsz, q_len, self.num_heads * self.head_dim - ).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral -# TODO(joao): add me back asap :) -class MistralSdpaAttention(MistralAttention): - """ - Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `MistralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from MistralAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin - ) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -MISTRAL_ATTENTION_CLASSES = { - "eager": MistralAttention, - "flash_attention_2": MistralFlashAttention2, - "sdpa": MistralSdpaAttention, -} - - -# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Mistral, LLAMA->MISTRAL -# TODO(joao): add me back asap :) -class MistralDecoderLayer(nn.Module): - def __init__(self, config: MistralConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation]( - config=config, layer_idx=layer_idx - ) - - self.mlp = MistralMLP(config) - self.input_layernorm = MistralRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - self.post_attention_layernorm = MistralRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -MISTRAL_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`MistralConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare Mistral Model outputting raw hidden-states without any specific head on top.", - MISTRAL_START_DOCSTRING, -) -class MistralPreTrainedModel(PreTrainedModel): - config_class = MistralConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["MistralDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - _supports_static_cache = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -MISTRAL_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance, see our - [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. - - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices indicating the position of the input sequence tokens in the sequence. Unlike `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. -""" - - -@add_start_docstrings( - "The bare Mistral Model outputting raw hidden-states without any specific head on top.", - MISTRAL_START_DOCSTRING, -) -class MistralModel(MistralPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`] - - Args: - config: MistralConfig - """ - - def __init__(self, config: MistralConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, self.padding_idx - ) - self.layers = nn.ModuleList( - [ - MistralDecoderLayer(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ] - ) - self._attn_implementation = config._attn_implementation - self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # retrieve input_ids and inputs_embeds - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You must specify exactly one of input_ids or inputs_embeds" - ) - - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) - - if cache_position is None: - past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 - ) - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device, - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = self._update_causal_mask( - attention_mask, - inputs_embeds, - cache_position, - past_key_values, - use_cache, - output_attentions, - ) - - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - +# TRACING: must use wrapped _prepare_4d_causal_attention_mask_with_cache_position +class MistralModel(MistralModel): def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -1090,12 +195,11 @@ def _update_causal_mask( return causal_mask -class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - _tp_plan = {"lm_head": "colwise_rep"} - +# TRACING: Must use MistralModel +class MistralForCausalLM(MistralForCausalLM): def __init__(self, config): - super().__init__(config) + super(MistralPreTrainedModel, self).__init__(config) + # TRACING: Must use MistralModel self.model = MistralModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) @@ -1103,264 +207,26 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC - ) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - - Returns: - Example: - - ```python - >>> from transformers import AutoTokenizer, MistralForCausalLM - - >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") - >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) - - loss = None - if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Ensure tensors are on the same device - shift_labels = shift_labels.to(shift_logits.device) - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - The Mistral Model transformer with a sequence classification head on top (linear layer). - - [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - MISTRAL_START_DOCSTRING, -) -# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL -class MistralForSequenceClassification(MistralPreTrainedModel): +# TRACING: Must use MistralModel +class MistralForSequenceClassification(MistralForSequenceClassification): def __init__(self, config): - super().__init__(config) + super(MistralPreTrainedModel, self).__init__(config) self.num_labels = config.num_labels + # TRACING: Must use MistralModel self.model = MistralModel(config) self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) # Initialize weights and apply final processing self.post_init() - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError( - "Cannot handle batch sizes > 1 if no padding token is defined." - ) - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = ( - torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 - ) - sequence_lengths = sequence_lengths % input_ids.shape[-1] - sequence_lengths = sequence_lengths.to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[ - torch.arange(batch_size, device=logits.device), sequence_lengths - ] - - loss = None - if labels is not None: - loss = self.loss_function( - logits=logits, - labels=labels, - pooled_logits=pooled_logits, - config=self.config, - ) - - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - -@add_start_docstrings( - """ - The Mistral Model transformer with a token classification head on top (a linear layer on top of the hidden-states - output) e.g. for Named-Entity-Recognition (NER) tasks. - """, - MISTRAL_START_DOCSTRING, -) -# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mistral, LLAMA->MISTRAL -class MistralForTokenClassification(MistralPreTrainedModel): +# TRACING: Must use MistralModel +class MistralForTokenClassification(MistralForTokenClassification): def __init__(self, config): - super().__init__(config) + super(MistralPreTrainedModel, self).__init__(config) self.num_labels = config.num_labels + # TRACING: Must use MistralModel self.model = MistralModel(config) if getattr(config, "classifier_dropout", None) is not None: classifier_dropout = config.classifier_dropout @@ -1374,159 +240,13 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = outputs[0] - sequence_output = self.dropout(sequence_output) - logits = self.score(sequence_output) - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.config) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ -The Mistral Model transformer with a span classification head on top for extractive question-answering tasks like -SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - MISTRAL_START_DOCSTRING, -) -# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->Mistral,LLAMA->MISTRAL,transformer->model -class MistralForQuestionAnswering(MistralPreTrainedModel): - base_model_prefix = "model" - - # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Mistral +# TRACING: Must use MistralModel +class MistralForQuestionAnswering(MistralForQuestionAnswering): def __init__(self, config): - super().__init__(config) + super(MistralPreTrainedModel, self).__init__(config) + # TRACING: Must use MistralModel self.model = MistralModel(config) self.qa_outputs = nn.Linear(config.hidden_size, 2) # Initialize weights and apply final processing self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - start_positions: Optional[torch.LongTensor] = None, - end_positions: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: - r""" - start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - - loss = None - if start_positions is not None and end_positions is not None: - loss = self.loss_function( - start_logits, end_logits, start_positions, end_positions, **kwargs - ) - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return QuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) From dfadc114ef3d3cfb34f32ff0aa8663a1a48b7b01 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 3 Jan 2025 18:24:58 -0500 Subject: [PATCH 264/285] remove propagate_error argument Signed-off-by: Kyle Sayers --- .../modifiers/quantization/gptq/base.py | 2 -- .../pipelines/layer_sequential/pipeline.py | 21 ++++++-------- .../pipelines/sequential/pipeline.py | 28 ++++++++----------- 3 files changed, 19 insertions(+), 32 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 0595ad7db..69e3c0975 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -228,7 +228,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self.sequential_targets, self.ignore, state.data.calib, - propagate_error=True, ) return True @@ -244,7 +243,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool: state.model, self.sequential_targets, state.data.calib, - propagate_error=True, ) return True diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index 73ef421f9..eb3ea6db5 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -1,4 +1,3 @@ -from contextlib import nullcontext from typing import List import torch @@ -20,7 +19,6 @@ def run_pipeline( model: torch.nn.Module, sequential_targets: List[str], dataloader: torch.utils.data.DataLoader, - propagate_error: bool, ): """ Run a layer-wise sequential data pipeline. @@ -50,18 +48,15 @@ def run_pipeline( calib_desc = f"({layer_index + 1}/{num_layers}): Calibrating" prop_desc = f"({layer_index + 1}/{num_layers}): Propagating" - if propagate_error: - # do an preliminary pass to trigger modifier hooks - for batch_index in tqdm.tqdm(range(len(dataloader)), desc=calib_desc): - inputs = intermediates.fetch(batch_index) - layer(**inputs) + # do an preliminary pass to trigger modifier hooks + for batch_index in tqdm.tqdm(range(len(dataloader)), desc=calib_desc): + inputs = intermediates.fetch(batch_index) + layer(**inputs) - # if using propagate_error, then this pass does not trigger modifier hooks - # and is only used for capturing intermediates - # otherwise, this pass triggers modifier hooks and captures intermediates - with HooksMixin.disable_hooks() if propagate_error else nullcontext(): - desc = prop_desc if propagate_error else calib_desc - for batch_index in tqdm.tqdm(range(len(dataloader)), desc=desc): + # this pass does not trigger modifier hooks + # and is only used for capturing outputs from the newly compressed modules + with HooksMixin.disable_hooks(): + for batch_index in tqdm.tqdm(range(len(dataloader)), desc=prop_desc): inputs = intermediates.fetch(batch_index) output = layer(**inputs) diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index f3e0a762f..ea9185967 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -1,4 +1,3 @@ -from contextlib import nullcontext from typing import List import torch @@ -19,15 +18,13 @@ def run_pipeline( sequential_targets: List[str], ignore: List[str], dataloader: torch.utils.data.DataLoader, - propagate_error: bool, ): """ Run a sequential data pipeline. 1. The model is partitioned into subgraphs according to `sequential_targets` - 2. Data passes through each subgraph sequentially. If `propagate_error` is enabled, - then data is passed through each subgraph twice, once to trigger calibration - hooks, then a second time in order to capture activations after quantization - has occurred through the hooks. + 2. Data passes through each subgraph sequentially. Data is passed through each + subgraph twice, once to trigger calibration hooks, then a second time in order + to capture activations after quantization has occurred through the hooks. 3. The intermediate activations between each subgraph are cached and offloaded to the cpu between each batch in order to save memory @@ -57,18 +54,15 @@ def run_pipeline( # compile subgraph forward function forward_function = subgraph.compile_forward() - if propagate_error: - # do an preliminary pass to trigger modifier hooks - for batch_index in tqdm.tqdm(range(len(dataloader)), desc=calib_desc): - inputs = intermediates.fetch(batch_index, subgraph.input_names) - forward_function(model, **inputs) + # do an preliminary pass to trigger modifier hooks + for batch_index in tqdm.tqdm(range(len(dataloader)), desc=calib_desc): + inputs = intermediates.fetch(batch_index, subgraph.input_names) + forward_function(model, **inputs) - # if using propagate_error, then this pass does not trigger modifier hooks - # and is only used for capturing intermediates - # otherwise, this pass triggers modifier hooks and captures intermediates - with HooksMixin.disable_hooks() if propagate_error else nullcontext(): - desc = prop_desc if propagate_error else calib_desc - for batch_index in tqdm.tqdm(range(len(dataloader)), desc=desc): + # this pass does not trigger modifier hooks + # and is only used for capturing outputs from the newly compressed modules + with HooksMixin.disable_hooks(): + for batch_index in tqdm.tqdm(range(len(dataloader)), desc=prop_desc): inputs = intermediates.fetch(batch_index, subgraph.input_names) output = forward_function(model, **inputs) From d1467714461cfe5ec1986d4f7a607a4cbe0b8697 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 4 Jan 2025 05:42:42 +0000 Subject: [PATCH 265/285] pipeline docstrings Signed-off-by: Kyle Sayers --- src/llmcompressor/pipelines/basic/pipeline.py | 8 ++++++++ src/llmcompressor/pipelines/layer_sequential/pipeline.py | 6 ++++-- src/llmcompressor/pipelines/sequential/pipeline.py | 3 ++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index 142698967..0ac0afbc2 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -11,6 +11,14 @@ def run_pipeline(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader): + """ + Run a basic data pipeline. + + Batches are fetched from the data loader and are used to perform forward passes + through the model. This pipeline is typically used for basic model calibration + and, unlike the sequential pipelines, does not propagate compression error when + used to calibrate model compression + """ model_device = get_execution_device(model) with calibration_forward_context(model): diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index eb3ea6db5..a071558b2 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -21,7 +21,8 @@ def run_pipeline( dataloader: torch.utils.data.DataLoader, ): """ - Run a layer-wise sequential data pipeline. + Run a layer-wise sequential data pipeline according to the following steps: + 1. Layers are identified according to `sequential_targets` 2. A hook is attached to the first layer. This hook raises an exception which is then caught and used to capture the input arguments to the first layer @@ -33,7 +34,8 @@ def run_pipeline( to the next layer. This is violated by encoder-decoder architectures among others. If your model architecture violates these assumptions, consider using the sequential - pipeline (see llmcompressor.pipelines.sequential) + pipeline (see llmcompressor.pipelines.sequential). Architectures which are known to + fail these assumptions include GPT-J and most vision language models """ # find layers layers = match_modules(model, sequential_targets) diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index ea9185967..fcf1d88b0 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -20,7 +20,8 @@ def run_pipeline( dataloader: torch.utils.data.DataLoader, ): """ - Run a sequential data pipeline. + Run a sequential data pipeline according to the following steps: + 1. The model is partitioned into subgraphs according to `sequential_targets` 2. Data passes through each subgraph sequentially. Data is passed through each subgraph twice, once to trigger calibration hooks, then a second time in order From bb77a449aaaed16933c79db82d1ab89a85eeeb38 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 4 Jan 2025 05:43:57 +0000 Subject: [PATCH 266/285] add gptq lifecycle docstring Signed-off-by: Kyle Sayers --- .../modifiers/quantization/gptq/base.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 69e3c0975..9f916d035 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -60,6 +60,18 @@ class GPTQModifier(Modifier, HooksMixin): | group_size: 128 | actorder: False + Lifecycle: + - on_initialize_structure + - _build_quant_modifier + - on_initialize + - register_hook(module, compress_module, "forward") + - run_sequential / run_layer_sequential / run_basic + - make_empty_hessian + - accumulate_hessian + - quantize_weight + - on_finalize + - remove_hooks() + - model.apply(freeze_module_quantization) :param sequential_targets: list of layer names to compress during GPTQ, or '__ALL__' to compress every layer in the model From 14f5d8836870c9e492203d6ba6303ca3c1e39a5d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 4 Jan 2025 06:00:14 +0000 Subject: [PATCH 267/285] layer sequential helpers docstrings Signed-off-by: Kyle Sayers --- .../pipelines/layer_sequential/helpers.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/llmcompressor/pipelines/layer_sequential/helpers.py b/src/llmcompressor/pipelines/layer_sequential/helpers.py index 8004539f7..94f101dc8 100644 --- a/src/llmcompressor/pipelines/layer_sequential/helpers.py +++ b/src/llmcompressor/pipelines/layer_sequential/helpers.py @@ -19,6 +19,13 @@ def match_modules(model: Module, target_names: List[str]) -> List[Module]: + """ + Find all submodules which match the `target_names` and sort them by name + + :param model: model to search for submodules in + :param target_names: patterns of submodule names to match + :return: list of submodules + """ names_layers = [ (name, module) for name, module in model.named_modules() @@ -35,6 +42,21 @@ def capture_first_layer_intermediates( dataloader: DataLoader, mask_padding: bool = True, ) -> IntermediatesCache: + """ + Captures the intermediate activations directly before the first model layer. + This is meant to capture any model preprocessing before model layers are executed + + Note that if any modules compressed prior to the execution of the first layer, the + compression error induced by compressing those modules will not be propagated to + subsequent activations, as they would be for modules which are compressed within + a layer + + :param model: model containing layers + :param layers: list of layer submodules in the model + :param dataloader: dataloader of calibration inputs + :param mask_padding: zero out padding tokens if True. This affects modifiers such as + GPTQ and SparseGPT + """ model_device = get_execution_device(model) intermediates = IntermediatesCache.empty(len(dataloader), torch.device("cpu")) first_layer = layers[0] @@ -64,6 +86,14 @@ def capture_first_layer_intermediates( def to_next_layer_kwargs(args: Tuple[Any, ...], next_layer: Module) -> Dict[str, Any]: + """ + Convert a list of arguments to a dictionary of keyword arguments which match the + next layer's function signature + + :param args: list of argument values + :param next_layer: the next layer whose function signature must be matched + :return: dictionary mapping function signature keywords to argument values + """ signature = inspect.signature(next_layer.forward) return args_to_kwargs(args, signature) From fde309acbe180244c7a627dede64232b17c8a6a9 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 4 Jan 2025 06:14:35 +0000 Subject: [PATCH 268/285] update comments Signed-off-by: Kyle Sayers --- .../pipelines/sequential/helpers.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index e8deeb9f3..f7768b588 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -77,12 +77,11 @@ def trace_subgraphs( def get_tracer( model: Module, sequential_targets: Set[Module], ignore: Set[Module] ) -> HFTracer: - offloaded_modules = set( - module for module in model.modules() if has_offloaded_params(module) - ) + offloaded_modules = set(m for m in model.modules() if has_offloaded_params(m)) class SequentialTracer(HFTracer): def create_arg(self, a: Any) -> Argument: + # special extension allows models which depend on config values to be traced if isinstance(a, PretrainedConfig): kwargs = {k: self.create_arg(v) for k, v in a.to_dict().items()} return self.create_node("call_function", a.__class__, (), kwargs) @@ -90,11 +89,11 @@ def create_arg(self, a: Any) -> Argument: else: return super().create_arg(a) - # Treat as leaf, skip tracing inside this module def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool: + # skip tracing sequential targets, offloaded modules, and ignored modules return ( module in sequential_targets - or module in offloaded_modules + or module in offloaded_modules # TODO: may be unnecessary or module in ignore or super().is_leaf_module(module, module_qualified_name) ) @@ -142,7 +141,8 @@ def topological_partition(graph: GraphModule, targets: Set[Module]) -> List[List } partition_index = 0 # global counter - # start with graph input nodes + # start with graph input nodes, + # but delay the `get_attr` nodes as long as possible queue = deque( node for node in graph.graph.nodes @@ -166,9 +166,9 @@ def topological_partition(graph: GraphModule, targets: Set[Module]) -> List[List if remaining_indegrees[user] == 0: queue.append(user) - # a perfect implementation would involve implicitly consolidating partition indices + # an ideal implementation would involve implicitly consolidating partition indices # so that each node is assigned to the maximum partition possible (in order to delay - # execution as long as possible), but the current implementation covers the most + # execution as long as possible), but saving these nodes for last covers the most # common and costly case (get_attr) for node in graph.graph.find_nodes(op="get_attr"): user_partitions = [] @@ -197,7 +197,6 @@ def partition_graph(model: Module, partitions: List[List[Node]]) -> List[Subgrap new_input_nodes = { input_node for node in partition_nodes - # if node.op != "get_attr" for input_node in node.all_input_nodes if input_node not in partition_nodes and input_node.op } From e6a8fa8c10ef2fb2e95260b7a7a9947f35e7ffcf Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 4 Jan 2025 06:53:35 +0000 Subject: [PATCH 269/285] sequential helpers docstrings Signed-off-by: Kyle Sayers --- .../pipelines/sequential/helpers.py | 50 ++++++++++++++++--- 1 file changed, 42 insertions(+), 8 deletions(-) diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index f7768b588..a7660238b 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -1,7 +1,7 @@ import inspect from collections import deque from dataclasses import dataclass -from typing import Any, Dict, List, Set +from typing import Any, Callable, Dict, List, Set from compressed_tensors import has_offloaded_params from compressed_tensors.quantization import find_name_or_class_matches @@ -20,11 +20,25 @@ @dataclass class Subgraph: + """ + Dataclass specifying an executable subgraph of a model graph + + :param graph: subgraph of model graph + :param input_names: argument names of the compiled forward function + :param consumed_names: argument names which are not used by any subsequent subgraphs + and can therefore be deleted from the intermediates cache + """ + graph: Graph input_names: Set[str] consumed_names: Set[str] - def compile_forward(self): + def compile_forward(self) -> Callable[[Any], Any]: + """ + Generate and compile code for executing this subgraph + + :return: function which, when called, executes this subgraph + """ code = self.graph.python_code("self") exec(code.src, code.globals) return code.globals.get("forward") @@ -36,6 +50,18 @@ def trace_subgraphs( sequential_targets: List[str], ignore: List[str], ) -> List[Subgraph]: + """ + Trace a model to produce subgraphs, where each sequential target belongs to exactly + one subgraph and where executing each subgraph in order is equivalent to executing + the original model + + :param model: model being traced + :param sample_input: inputs whose values will change during execution but whose + __len__, __bool__, and __contains__ values are assumed constant across batches + :param sequential_targets: list of patterns specifying sequential targets + :param ignore: list of patterns specifying modules to ignore during tracing + :return: a list of Subgraphs in order of execution + """ # find modules sequential_targets = match_modules(model, sequential_targets) ignore = match_modules(model, ignore) @@ -77,7 +103,19 @@ def trace_subgraphs( def get_tracer( model: Module, sequential_targets: Set[Module], ignore: Set[Module] ) -> HFTracer: + """ + Get a tracer specialized for the given model. The resulting tracer will not trace + inside of sequential targets, ignored targets, or offloaded modules. + + Tracing within sequential targets and ignored targets is unnecessary, and tracing + within offloaded modules may result in meta tensors being added to the model graph + + :param model: model being traced + :param sequential_targets: modules which are sequential targets + :param ignore: modules which are ignored + """ offloaded_modules = set(m for m in model.modules() if has_offloaded_params(m)) + skip_trace_modules = sequential_targets | offloaded_modules | ignore class SequentialTracer(HFTracer): def create_arg(self, a: Any) -> Argument: @@ -90,12 +128,8 @@ def create_arg(self, a: Any) -> Argument: return super().create_arg(a) def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool: - # skip tracing sequential targets, offloaded modules, and ignored modules - return ( - module in sequential_targets - or module in offloaded_modules # TODO: may be unnecessary - or module in ignore - or super().is_leaf_module(module, module_qualified_name) + return module in skip_trace_modules or super().is_leaf_module( + module, module_qualified_name ) return SequentialTracer() From 954cd4ebbca90a6fda52ecf5ac2fc880033d28cb Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 4 Jan 2025 07:19:12 +0000 Subject: [PATCH 270/285] more docstrings Signed-off-by: Kyle Sayers --- .../pipelines/sequential/helpers.py | 68 +++++++++++++++++-- 1 file changed, 64 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index a7660238b..2364e8042 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -136,6 +136,17 @@ def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool: def populate_concrete_args(model: Module, sample_input: Dict) -> Dict: + """ + Creates concrete args which, unlike the equivalent function provided by + transformers.utils.fx, creates default values for variadic arguments, which are + needed by some models. + + :param model: model being traced + :param sample_input: values used to symbolically trace the model. All arguments + to the model.forward function which are not in the sample_input are considered + concrete args + :return: dictionary mapping concrete argument names to their default values + """ sig = inspect.signature(model.forward) concrete_args = {} @@ -156,7 +167,15 @@ def populate_concrete_args(model: Module, sample_input: Dict) -> Dict: return concrete_args -def get_target_nodes(graph: GraphModule, targets: Set[Module]) -> Set[Node]: +def find_target_nodes(graph: GraphModule, targets: Set[Module]) -> Set[Node]: + """ + Find all nodes whose execution is equivalent to executing the target modules. + Note that these nodes are guaranteed to be treated as leaf nodes by SequentialTracer + + :param graph: graph containing target nodes + :param targets: modules whose nodes are being searched for + :return: set of all nodes which call the target modules + """ return set( node for node in graph.graph.nodes @@ -165,8 +184,18 @@ def get_target_nodes(graph: GraphModule, targets: Set[Module]) -> Set[Node]: def topological_partition(graph: GraphModule, targets: Set[Module]) -> List[List[Node]]: + """ + Partition the graph into partitions such that each `target` belongs to exactly one + partition and executing each partition depends only on intermediate values produced + by executing the partitions before it. + + :param graph: graph being partitioned + :param targets: target modules which will be assigned to disjoint partitions + :return: list of partitions, where each partition is a list of nodes belong to that + partition + """ assert check_assumption(graph.graph) - target_nodes = get_target_nodes(graph, targets) + target_nodes = find_target_nodes(graph, targets) partitions: List[List[Node]] = [[]] remaining_indegrees = { @@ -219,6 +248,17 @@ def topological_partition(graph: GraphModule, targets: Set[Module]) -> List[List def partition_graph(model: Module, partitions: List[List[Node]]) -> List[Subgraph]: + """ + Convert each partition into a Subgraph. Each Subgraph returns a dictionary mapping + of output node names to their computed values. Note that the `consumed_names` + attribute of each Subgraph remains empty, to be later populated by + `trace_consumed_names` + + :param model: model which owns the produced Subgraphs + :param partitions: list of partitions, where each partition is a list of nodes + belong to that partition + :return: list of subgraphs in order of execution + """ subgraphs = [] # create subgraphs @@ -266,7 +306,13 @@ def partition_graph(model: Module, partitions: List[List[Node]]) -> List[Subgrap return subgraphs -def trace_consumed_names(subgraphs: List[Dict[str, Any]]): +def trace_consumed_names(subgraphs: List[Subgraph]): + """ + Populate the `consumed_names` attribute of each Subgraph according to when inputs + are last used in order to vacate the `intermediates` cache and save memory + + :param subgraphs: list of subgraphs with empty `consumed_names` attributes + """ # populate consumed_names according to when inputs are last used # in order to vacate the `intermediates` cache and save memory all_input_names = set().union(*(subgraph.input_names for subgraph in subgraphs)) @@ -276,10 +322,17 @@ def trace_consumed_names(subgraphs: List[Dict[str, Any]]): subgraph.consumed_names.add(input_name) break else: - assert False + raise ValueError(f"Could not find input name {input_name} in subgraphs") def check_assumption(graph: Graph) -> bool: + """ + Checks that a graph is not malformed + + :param graph: graph being checked + :return: True if node.users and node.all_input_nodes have bidirectional + relationships, False otherwise + """ for node in graph.nodes: for user in node.users: if node not in user.all_input_nodes: @@ -298,6 +351,13 @@ def check_assumption(graph: Graph) -> bool: def match_modules(model: Module, target_names: List[str]) -> Set[Module]: + """ + Find modules whose names matach the patterns given by `target_names` + + :param model: model containing submodules to find + :param target_names: target patterns to find + :return: all submodules matching `target_names` + """ return set( module for name, module in model.named_modules() From 00309e9eceee195bb8bf1cd5728e74402e20c52b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 4 Jan 2025 07:42:06 +0000 Subject: [PATCH 271/285] IntermediatesCache docstrings Signed-off-by: Kyle Sayers --- src/llmcompressor/pipelines/cache.py | 53 ++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/src/llmcompressor/pipelines/cache.py b/src/llmcompressor/pipelines/cache.py index b4e8a440c..418fb9e4d 100644 --- a/src/llmcompressor/pipelines/cache.py +++ b/src/llmcompressor/pipelines/cache.py @@ -8,11 +8,29 @@ @dataclass class IntermediateValue: + """ + Dataclass which recursively defines offloaded values and which device to onload to + + :param value: either an offloaded Tensor, an primative value, or a recursable value + :param device: if the value is a Tensor, then the device to onload the tensor to, + otherwise None + """ + value: Union[torch.Tensor, "IntermediateValue", Any] device: Union[torch.device, None] class IntermediatesCache: + """ + Cache which stores intermediate values (activations) produced by batched, sequential + execution of models. Values are offloaded to the `offload_device` when stored in + the cache and onloaded to their original device when fetched from the cache + + Currently supports nested offloading of dataclass instances and tuples + + Construct using `empty` and `from_dataloader` class methods + """ + batch_intermediates: List[Dict[str, IntermediateValue]] offload_device: torch.device @@ -26,6 +44,12 @@ def __init__( @classmethod def empty(cls, num_batches: int, offload_device: torch.device): + """ + Construct an empty cache + + :param num_batches: the expected number of batches to be stored + :param offload_device: device to offload values to + """ batch_intermediates = [{} for _ in range(num_batches)] return cls(batch_intermediates, offload_device) @@ -37,6 +61,15 @@ def from_dataloader( mask_padding: bool = True, offload_device: torch.device = "cpu", ): + """ + Initialize a cache with data from the provided dataloader + + :param dataloader: dataloader which generates values to be cached + :param model_device: device which values will be onloaded to when fetched + :param mask_padding: zero out padding tokens if True. This affects modifiers + such as GPTQ and SparseGPT + :param offload_device: device to offload values to + """ batch_intermediates = [ { key: ( @@ -57,6 +90,13 @@ def from_dataloader( def fetch( self, batch_index: int, input_names: Optional[List[str]] = None ) -> Dict[str, Any]: + """ + Fetch values belonging to a batch + + :param batch_index: index of batch whose values are being fetched + :param input_names: list of keys whose values are being fetched + :return: dictionary mapping keys to onloaded values + """ intermediates = self.batch_intermediates[batch_index] return { @@ -66,10 +106,23 @@ def fetch( } def update(self, batch_index: int, values: Dict[str, Any]): + """ + Update/put values belonging to a batch + + :param batch_index: index of batch whose values will be updated + :param values: dictionary mapping keys to values used for update + """ intermediates = {k: self._offload_value(v) for k, v in values.items()} self.batch_intermediates[batch_index].update(intermediates) def delete(self, batch_index: int, consumed_names: Optional[List[str]] = None): + """ + Delete values from the cache + + :param batch_index: index of batch whose values will be deleted + :param consumed_names: list of keys whose values will be deleted, defaults to + removing all keys + """ intermediates = self.batch_intermediates[batch_index] if consumed_names is None: From 57e8f2168d10545409c7d85cfc7e7ac89da04eb6 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 4 Jan 2025 07:45:08 +0000 Subject: [PATCH 272/285] free hessians on finalize Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/quantization/gptq/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 9f916d035..136bfdb9a 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -282,6 +282,8 @@ def on_finalize(self, state: "State", **kwargs) -> bool: self._quantization_modifier.finalize(state, **kwargs) self.remove_hooks() + self._hessians = dict() + self._num_samples = dict() state.model.apply(freeze_module_quantization) return True From 378afb3350356ae896ce425806b0faa4981de7e4 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 4 Jan 2025 17:46:55 +0000 Subject: [PATCH 273/285] remove unnecessary examples Signed-off-by: Kyle Sayers --- .../quantization_w4a16/minicpm_example.py | 82 ----------------- .../quantization_w4a16/mixtral_example.py | 89 ------------------- 2 files changed, 171 deletions(-) delete mode 100644 examples/quantization_w4a16/minicpm_example.py delete mode 100644 examples/quantization_w4a16/mixtral_example.py diff --git a/examples/quantization_w4a16/minicpm_example.py b/examples/quantization_w4a16/minicpm_example.py deleted file mode 100644 index 826d01d34..000000000 --- a/examples/quantization_w4a16/minicpm_example.py +++ /dev/null @@ -1,82 +0,0 @@ -from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer - -from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.transformers import oneshot - -# Select model and load it. -MODEL_ID = "openbmb/MiniCPM3-4B" - -model = AutoModelForCausalLM.from_pretrained( - MODEL_ID, - device_map="auto", - torch_dtype="auto", - trust_remote_code=True, -) -tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) - -# Select calibration dataset. -DATASET_ID = "HuggingFaceH4/ultrachat_200k" -DATASET_SPLIT = "train_sft" - -# Select number of samples. 512 samples is a good place to start. -# Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 512 -MAX_SEQUENCE_LENGTH = 2048 - -# Load dataset and preprocess. -ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) -ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) - - -def preprocess(example): - return { - "text": tokenizer.apply_chat_template( - example["messages"], - tokenize=False, - ) - } - - -ds = ds.map(preprocess) - - -# Tokenize inputs. -def tokenize(sample): - return tokenizer( - sample["text"], - padding=False, - max_length=MAX_SEQUENCE_LENGTH, - truncation=True, - add_special_tokens=False, - ) - - -ds = ds.map(tokenize, remove_columns=ds.column_names) - -# Configure the quantization algorithm to run. -# * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]) - -# Apply algorithms. -oneshot( - model=model, - tokenizer=tokenizer, - dataset=ds, - recipe=recipe, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, -) - -# Save to disk compressed. -SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128" -model.save_pretrained(SAVE_DIR, save_compressed=True) -tokenizer.save_pretrained(SAVE_DIR) - -# Confirm generations of the quantized model look sane. -print("\n\n") -print("========== SAMPLE GENERATION ==============") -input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") -output = model.generate(input_ids, max_new_tokens=100) -print(tokenizer.decode(output[0])) -print("==========================================\n\n") diff --git a/examples/quantization_w4a16/mixtral_example.py b/examples/quantization_w4a16/mixtral_example.py deleted file mode 100644 index 10834725a..000000000 --- a/examples/quantization_w4a16/mixtral_example.py +++ /dev/null @@ -1,89 +0,0 @@ -from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer - -from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.transformers import oneshot -from llmcompressor.transformers.compression.helpers import calculate_offload_device_map - -# Select model and load it. -MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1" -device_map = calculate_offload_device_map( - MODEL_ID, - reserve_for_hessians=True, - num_gpus=1, - torch_dtype="auto", -) - -model = AutoModelForCausalLM.from_pretrained( - MODEL_ID, - device_map=device_map, - torch_dtype="auto", -) -tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) - -# Select calibration dataset. -DATASET_ID = "HuggingFaceH4/ultrachat_200k" -DATASET_SPLIT = "train_sft" - -# Select number of samples. 512 samples is a good place to start. -# Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 512 -MAX_SEQUENCE_LENGTH = 2048 - -# Load dataset and preprocess. -ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) -ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) - - -def preprocess(example): - return { - "text": tokenizer.apply_chat_template( - example["messages"], - tokenize=False, - ) - } - - -ds = ds.map(preprocess) - - -# Tokenize inputs. -def tokenize(sample): - return tokenizer( - sample["text"], - padding=False, - max_length=MAX_SEQUENCE_LENGTH, - truncation=True, - add_special_tokens=False, - ) - - -ds = ds.map(tokenize, remove_columns=ds.column_names) - -# Configure the quantization algorithm to run. -# * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = GPTQModifier( - targets="Linear", scheme="W4A16", ignore=["lm_head", "re:.*block_sparse_moe.gate"] -) - -# Apply algorithms. -oneshot( - model=model, - dataset=ds, - recipe=recipe, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, -) - -# Confirm generations of the quantized model look sane. -print("\n\n") -print("========== SAMPLE GENERATION ==============") -input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") -output = model.generate(input_ids, max_new_tokens=100) -print(tokenizer.decode(output[0])) -print("==========================================\n\n") - -# Save to disk compressed. -SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128" -model.save_pretrained(SAVE_DIR, save_compressed=True) -tokenizer.save_pretrained(SAVE_DIR) From 83b81beb3ee5b0b18b5a522e138a04a3659dad0c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 4 Jan 2025 18:10:49 +0000 Subject: [PATCH 274/285] make diff closer to original implementation Signed-off-by: Kyle Sayers --- .../modifiers/quantization/gptq/utils/gptq_quantize.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index decb3a33d..45290d2f6 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -88,22 +88,21 @@ def quantize_weight( actorder = quant_args.actorder final_shape = module.weight.shape final_dtype = module.weight.dtype - module_class = type(module) W = module.weight.clone() H = hessians_dict[module] # unfortunately python does not have a `move` keyword del hessians_dict[module] # so we have to delete the original reference manually # create observer for calculating quantization parameters observer = Observer.load_from_registry( - "minmax", + quant_args.observer, quantization_args=quant_args, averaging_constant=1.0, # ignore moving average ) # standardize shape and dtype - if module_class == torch.nn.Conv2d: + if isinstance(module, torch.nn.Conv2d): W = W.flatten(1) - elif module_class == transformers.Conv1D: + elif isinstance(module, transformers.Conv1D): W.transpose_(0, 1) W = W.to(dtype=GPTQ_PRECISION) num_rows = W.shape[0] @@ -263,7 +262,7 @@ def quantize_weight( if not has_gidx: g_idx = None - if module_class == transformers.Conv1D: + if isinstance(module, transformers.Conv1D): W.transpose_(0, 1) W = W.reshape(final_shape).to(final_dtype) From 5363d401f21e58908b23ce162a10665e3fa83210 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 4 Jan 2025 18:29:05 +0000 Subject: [PATCH 275/285] use original mask padding function Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/utils/pytorch_helpers.py | 6 ++++-- src/llmcompressor/pipelines/cache.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/modifiers/utils/pytorch_helpers.py b/src/llmcompressor/modifiers/utils/pytorch_helpers.py index 00165c98a..c9869f267 100644 --- a/src/llmcompressor/modifiers/utils/pytorch_helpers.py +++ b/src/llmcompressor/modifiers/utils/pytorch_helpers.py @@ -34,12 +34,14 @@ def apply_pad_mask_to_batch(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.T """ Apply a mask to the input ids of a batch. This is used to zero out padding tokens so they do not contribute to the hessian calculation in the - SparseGPT algorithm + GPTQ and SparseGPT algorithms + + Assumes that `attention_mask` only contains zeros and ones :param batch: batch to apply padding to if it exists :return: batch with padding zeroed out in the input_ids """ - batch["input_ids"].masked_fill_(batch["attention_mask"] == 0, 0) + batch["input_ids"] = batch["input_ids"] * batch["attention_mask"] return batch diff --git a/src/llmcompressor/pipelines/cache.py b/src/llmcompressor/pipelines/cache.py index 418fb9e4d..6369e3f89 100644 --- a/src/llmcompressor/pipelines/cache.py +++ b/src/llmcompressor/pipelines/cache.py @@ -187,4 +187,5 @@ def _mask_padding( # some attention masks, such as those from pixtral, are are 4d attention_mask = attention_mask[0, 0, 0].unsqueeze(0) - return input_ids.masked_fill_(torch.logical_not(attention_mask), 0) + # Assumes that `attention_mask` only contains zeros and ones + return input_ids * attention_mask From ae8968849db669c96d59f6108487a7b584807a54 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 4 Jan 2025 18:35:26 +0000 Subject: [PATCH 276/285] reduce diff Signed-off-by: Kyle Sayers --- .../transformers/compression/helpers.py | 34 +++++++------------ 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/src/llmcompressor/transformers/compression/helpers.py b/src/llmcompressor/transformers/compression/helpers.py index fee98137f..35ab51220 100644 --- a/src/llmcompressor/transformers/compression/helpers.py +++ b/src/llmcompressor/transformers/compression/helpers.py @@ -114,8 +114,8 @@ def infer_sparsity_structure_from_model(model: torch.nn.Module) -> Optional[str] def hessian_memory_requirements(model: torch.nn.Module) -> int: """ Determines the number of bytes needed to store Hessian data for a single - transformer layer in model. This is used for reserving memory for GPTQModifier - or SparseGPTModifier + transformer layer in model. This is used for reserving memory for GPTQ + quantization :param model: model to calculate requirements for :return: number of bytes required to reserve for GPTQ on a single layer @@ -192,6 +192,7 @@ def custom_offload_device_map( memory_limits = {device: max_memory_per_gpu for device in range(num_gpus)} memory_limits["cpu"] = max_cpu_memory + device_map = {} with init_empty_weights(): dummy_model = model_cls.from_pretrained(model_stub, **model_kwargs) device_map = infer_auto_device_map( @@ -199,6 +200,7 @@ def custom_offload_device_map( max_memory=memory_limits, no_split_module_classes=dummy_model._no_split_modules, ) + del dummy_model return device_map @@ -206,8 +208,7 @@ def custom_offload_device_map( def calculate_offload_device_map( model_stub: str, reserve_for_hessians=False, - num_gpus: Optional[int] = None, - gpu_ids: Optional[List[int]] = None, + num_gpus: int = 1, torch_dtype: torch.dtype = torch.float16, model_cls: Type = AutoModelForCausalLM, **model_kwargs, @@ -217,33 +218,23 @@ def calculate_offload_device_map( into account extra memory required for quantization and (optionally) GPTQ hessians :param model_stub: local path or HF stub to calculate mapping for - :param reserve_for_hessians: whether to reserve memory for GPTQ/OBCQ - :param num_gpus: number of gpus to utilize, defaults to max available - :param gpu_ids: list of gpu device ids to utilize, overrides num_gpus if provided - :param torch_dtype: datatype in which model weights are to be loaded with + :param reserve_for_hessians: whether to reserve memory for GPTQ + :param num_gpus: number of gpus to utilize :param model_cls: model class to use when initializing model structure, default is AutoModelForCausalLM :param model_kwargs: keyword arguments to pass to model initializer :return: memory mapping for layers of model_stub to be passed to from_pretrained() """ max_cpu_memory = psutil.virtual_memory().available + max_gpu_memory = torch.cuda.mem_get_info(0)[0] available_gpus = torch.cuda.device_count() - if gpu_ids is None: - if num_gpus is None: - num_gpus = available_gpus - gpu_ids = range(num_gpus) - else: - num_gpus = len(gpu_ids) - - if num_gpus > available_gpus: + if available_gpus < num_gpus: raise ValueError( f"Requested {num_gpus} GPUs but only {available_gpus} are available." ) + max_gpu_memory = [max_gpu_memory] * num_gpus - max_gpu_memory = { - device_id: torch.cuda.mem_get_info(device_id)[0] for device_id in gpu_ids - } - + device_map = {} with init_empty_weights(): dummy_model = model_cls.from_pretrained( model_stub, torch_dtype=torch_dtype, **model_kwargs @@ -256,7 +247,7 @@ def calculate_offload_device_map( memory_limits = { idx: (max_memory - reserved_memory) - for idx, max_memory in max_gpu_memory.items() + for idx, max_memory in enumerate(max_gpu_memory) } memory_limits["cpu"] = max_cpu_memory @@ -265,6 +256,7 @@ def calculate_offload_device_map( max_memory=memory_limits, no_split_module_classes=dummy_model._no_split_modules, ) + del dummy_model return device_map From d3eebfe391f8fc913c90d97292f84c39ea721b10 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 6 Jan 2025 21:00:49 +0000 Subject: [PATCH 277/285] replace list comprehesion Signed-off-by: Kyle Sayers --- src/llmcompressor/pipelines/cache.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/src/llmcompressor/pipelines/cache.py b/src/llmcompressor/pipelines/cache.py index 6369e3f89..57c9b1486 100644 --- a/src/llmcompressor/pipelines/cache.py +++ b/src/llmcompressor/pipelines/cache.py @@ -59,7 +59,7 @@ def from_dataloader( dataloader: torch.utils.data.DataLoader, model_device: torch.device, mask_padding: bool = True, - offload_device: torch.device = "cpu", + offload_device: torch.device = torch.device("cpu"), ): """ Initialize a cache with data from the provided dataloader @@ -70,20 +70,16 @@ def from_dataloader( such as GPTQ and SparseGPT :param offload_device: device to offload values to """ - batch_intermediates = [ - { - key: ( - IntermediateValue( - value=cls._mask_padding(value, batch["attention_mask"]), - device=model_device, - ) - if mask_padding and key == "input_ids" - else IntermediateValue(value=value, device=model_device) - ) - for key, value in batch.items() - } - for batch in tqdm.tqdm(dataloader, desc="Preparing intermediates cache") - ] + # note: list comprehesion was found to not improve performance + batch_intermediates = [] + for batch in tqdm.tqdm(dataloader, desc="Preparing intermediates cache"): + intermediate = {} + for key, value in batch.items(): + if mask_padding and key == "input_ids": + value = cls._mask_padding(value, batch["attention_mask"]) + intermediate[key] = IntermediateValue(value=value, device=model_device) + + batch_intermediates.append(intermediate) return cls(batch_intermediates, offload_device) From 412086c7b3973870afd80dca9898757ccfa5b5b4 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 6 Jan 2025 21:01:21 +0000 Subject: [PATCH 278/285] nit: only pass first layer Signed-off-by: Kyle Sayers --- src/llmcompressor/pipelines/layer_sequential/helpers.py | 5 ++--- src/llmcompressor/pipelines/layer_sequential/pipeline.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/pipelines/layer_sequential/helpers.py b/src/llmcompressor/pipelines/layer_sequential/helpers.py index 94f101dc8..762516a1c 100644 --- a/src/llmcompressor/pipelines/layer_sequential/helpers.py +++ b/src/llmcompressor/pipelines/layer_sequential/helpers.py @@ -38,7 +38,7 @@ def match_modules(model: Module, target_names: List[str]) -> List[Module]: def capture_first_layer_intermediates( model: Module, - layers: List[Module], + first_layer: Module, dataloader: DataLoader, mask_padding: bool = True, ) -> IntermediatesCache: @@ -52,14 +52,13 @@ def capture_first_layer_intermediates( a layer :param model: model containing layers - :param layers: list of layer submodules in the model + :param first_layer: the first layer of the model :param dataloader: dataloader of calibration inputs :param mask_padding: zero out padding tokens if True. This affects modifiers such as GPTQ and SparseGPT """ model_device = get_execution_device(model) intermediates = IntermediatesCache.empty(len(dataloader), torch.device("cpu")) - first_layer = layers[0] signature = inspect.signature(first_layer.forward) with calibration_forward_context(model), early_stop_hook(first_layer): diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index a071558b2..27836c921 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -42,7 +42,7 @@ def run_pipeline( with calibration_forward_context(model): # prepare intermediates cache - intermediates = capture_first_layer_intermediates(model, layers, dataloader) + intermediates = capture_first_layer_intermediates(model, layers[0], dataloader) num_layers = len(layers) for layer_index, layer in enumerate(layers): From 84333045dd70ce83e9684e643ffb1fa7a06a90a2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 6 Jan 2025 21:01:42 +0000 Subject: [PATCH 279/285] revert changes to tensors_to_device Signed-off-by: Kyle Sayers --- src/llmcompressor/pytorch/utils/helpers.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py index 2f9074f9b..1a0724e6c 100644 --- a/src/llmcompressor/pytorch/utils/helpers.py +++ b/src/llmcompressor/pytorch/utils/helpers.py @@ -289,7 +289,8 @@ def tensors_to_device( Default function for putting a tensor or collection of tensors to the proper device. Returns the tensor references after being placed on the proper device. - Recursive cases: + Supported use cases: + - single tensor - Dictionary of single tensors - Dictionary of iterable of tensors - Dictionary of dictionary of tensors @@ -319,7 +320,9 @@ def tensors_to_device( if isinstance(tensors, Iterable): return [tensors_to_device(tens, device) for tens in tensors] - return tensors + raise ValueError( + "unrecognized type for tensors given of {}".format(tensors.__class__.__name__) + ) def tensors_to_precision( From 07b3cc35878c96bb81cfd7f941b9cbaad27fbb7d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 6 Jan 2025 21:04:44 +0000 Subject: [PATCH 280/285] type hint intermediates cache for clarity Signed-off-by: Kyle Sayers --- src/llmcompressor/pipelines/layer_sequential/pipeline.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index 27836c921..a1d38e6f0 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -5,6 +5,7 @@ import tqdm from llmcompressor.modifiers.utils.hooks import HooksMixin +from llmcompressor.pipelines.cache import IntermediatesCache from llmcompressor.pipelines.layer_sequential.helpers import ( capture_first_layer_intermediates, match_modules, @@ -42,7 +43,9 @@ def run_pipeline( with calibration_forward_context(model): # prepare intermediates cache - intermediates = capture_first_layer_intermediates(model, layers[0], dataloader) + intermediates: IntermediatesCache = capture_first_layer_intermediates( + model, layers[0], dataloader + ) num_layers = len(layers) for layer_index, layer in enumerate(layers): From 895b409a9753568a61903262d4fc3b1f7f57f780 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 6 Jan 2025 21:07:56 +0000 Subject: [PATCH 281/285] make hessian instability a _LinAlgError so it can be caught by gptq fallbacks Signed-off-by: Kyle Sayers --- .../modifiers/quantization/gptq/utils/gptq_quantize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py index 45290d2f6..6f0ae60fb 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -161,7 +161,7 @@ def quantize_weight( H = torch.linalg.cholesky(H, upper=True) Hinv = H except torch._C._LinAlgError: - raise ValueError( + raise torch._C._LinAlgError( "Failed to invert hessian due to numerical instability. Consider " "increasing GPTQModifier.dampening_frac, increasing the number " "of calibration samples, or shuffling the calibration dataset" From 336e0648b55478d55783a8337ef2b2e5ce84ebc3 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 6 Jan 2025 23:55:19 +0000 Subject: [PATCH 282/285] defer chatglm for later Signed-off-by: Kyle Sayers --- .../transformers/tracing/__init__.py | 4 - .../transformers/tracing/glm/LICENSE | 84 - .../tracing/glm/configuration_chatglm.py | 68 - .../tracing/glm/modeling_chatglm.py | 1349 ----------------- .../transformers/tracing/glm/visual.py | 182 --- .../transformers/utils/data_collator.py | 11 - 6 files changed, 1698 deletions(-) delete mode 100644 src/llmcompressor/transformers/tracing/glm/LICENSE delete mode 100644 src/llmcompressor/transformers/tracing/glm/configuration_chatglm.py delete mode 100644 src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py delete mode 100644 src/llmcompressor/transformers/tracing/glm/visual.py diff --git a/src/llmcompressor/transformers/tracing/__init__.py b/src/llmcompressor/transformers/tracing/__init__.py index 19abc1a9d..4baa5864d 100644 --- a/src/llmcompressor/transformers/tracing/__init__.py +++ b/src/llmcompressor/transformers/tracing/__init__.py @@ -1,6 +1,3 @@ -from .glm.modeling_chatglm import ( - ChatGLMForConditionalGeneration as TraceableChatGLMForConditionalGeneration, -) from .llava import ( LlavaForConditionalGeneration as TraceableLlavaForConditionalGeneration, ) @@ -13,5 +10,4 @@ "TraceableLlavaForConditionalGeneration", "TraceableMllamaForConditionalGeneration", "TraceableMistralForCausalLM", - "TraceableChatGLMForConditionalGeneration", ] diff --git a/src/llmcompressor/transformers/tracing/glm/LICENSE b/src/llmcompressor/transformers/tracing/glm/LICENSE deleted file mode 100644 index 7b7c19f56..000000000 --- a/src/llmcompressor/transformers/tracing/glm/LICENSE +++ /dev/null @@ -1,84 +0,0 @@ -The glm-4-9b License - -1. 定义 - -“许可方”是指分发其软件的 glm-4-9b 模型团队。 -“软件”是指根据本许可提供的 glm-4-9b 模型参数。 - -2. 许可授予 - -根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。 -本许可允许您免费使用本仓库中的所有开源模型进行学术研究,对于希望将模型用于商业目的的用户,需在[这里](https://open.bigmodel.cn/mla/form)完成登记。经过登记的用户可以免费使用本模型进行商业活动,但必须遵守本许可的所有条款和条件。 -上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。 -如果您分发或提供 THUDM / 智谱AI 关于 glm-4 开源模型的材料(或其任何衍生作品),或使用其中任何材料(包括 glm-4 系列的所有开源模型)的产品或服务,您应: - -(A) 随任何此类 THUDM / 智谱AI 材料提供本协议的副本; -(B) 在相关网站、用户界面、博客文章、关于页面或产品文档上突出显示 “Built with glm-4”。 -如果您使用 THUDM / 智谱AI的 glm-4 开源模型的材料来创建、训练、微调或以其他方式改进已分发或可用的 AI 模型,您还应在任何此类 AI 模型名称的开头添加 “glm-4”。 - -3. 限制 - -您不得出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。 -您不得利用本软件从事任何危害国家安全和国家统一,危害社会公共利益及公序良俗,侵犯他人商业秘密、知识产权、名誉权、肖像权、财产权等权益的行为。 -您在使用中应遵循使用地所适用的法律法规政策、道德规范等要求。 - -4. 免责声明 - -本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。 -在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 -软件。 - -5. 责任限制 - -除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 -或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。 - -6. 争议解决 - -本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。 -请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 license@zhipuai.cn 与我们联系。 - -1. Definitions - -“Licensor” means the glm-4-9b Model Team that distributes its Software. -“Software” means the glm-4-9b model parameters made available under this license. - -2. License - -Under the terms and conditions of this license, the Licensor hereby grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license. -This license allows you to use all open source models in this repository for free for academic research. For users who wish to use the models for commercial purposes, please do so [here](https://open.bigmodel.cn/mla/form) -Complete registration. Registered users are free to use this model for commercial activities, but must comply with all terms and conditions of this license. -The copyright notice and this license notice shall be included in all copies or substantial portions of the Software. -If you distribute or provide THUDM / Zhipu AI materials on the glm-4 open source model (or any derivative works thereof), or products or services that use any materials therein (including all open source models of the glm-4 series), you should: - -(A) Provide a copy of this Agreement with any such THUDM/Zhipu AI Materials; -(B) Prominently display "Built with glm-4" on the relevant website, user interface, blog post, related page or product documentation. -If you use materials from THUDM/Zhipu AI's glm-4 model to create, train, operate, or otherwise improve assigned or available AI models, you should also add "glm-4" to the beginning of any such AI model name. - -3. Restrictions - -You are not allowed to use, copy, modify, merge, publish, distribute, copy or create all or part of the derivative works of this software for any military or illegal purposes. -You are not allowed to use this software to engage in any behavior that endangers national security and unity, endangers social public interests and public order, infringes on the rights and interests of others such as trade secrets, intellectual property rights, reputation rights, portrait rights, and property rights. -You should comply with the applicable laws, regulations, policies, ethical standards, and other requirements in the place of use during use. - -4. Disclaimer - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE -WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR -COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR -OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -5. Limitation of Liability - -EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, -NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, -INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED -OF THE POSSIBILITY OF SUCH DAMAGES. - -6. Dispute Resolution - -This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute -arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. - -Note that the license is subject to update to a more comprehensive version. For any questions related to the license and -copyright, please contact us at license@zhipuai.cn. \ No newline at end of file diff --git a/src/llmcompressor/transformers/tracing/glm/configuration_chatglm.py b/src/llmcompressor/transformers/tracing/glm/configuration_chatglm.py deleted file mode 100644 index 2487be8ce..000000000 --- a/src/llmcompressor/transformers/tracing/glm/configuration_chatglm.py +++ /dev/null @@ -1,68 +0,0 @@ -# flake8: noqa -# vllm-project: no copyright -from transformers import PretrainedConfig - - -class ChatGLMConfig(PretrainedConfig): - model_type = "chatglm" - - def __init__( - self, - num_layers=28, - padded_vocab_size=65024, - hidden_size=4096, - ffn_hidden_size=13696, - kv_channels=128, - num_attention_heads=32, - seq_length=2048, - hidden_dropout=0.0, - classifier_dropout=None, - attention_dropout=0.0, - layernorm_epsilon=1e-5, - rmsnorm=True, - apply_residual_connection_post_layernorm=False, - post_layer_norm=True, - add_bias_linear=False, - add_qkv_bias=False, - bias_dropout_fusion=True, - multi_query_attention=False, - multi_query_group_num=1, - rope_ratio=1, - apply_query_key_layer_scaling=True, - attention_softmax_in_fp32=True, - fp32_residual_connection=False, - pre_seq_len=None, - prefix_projection=False, - boi_token_id=None, - eoi_token_id=None, - **kwargs - ): - self.num_layers = num_layers - self.vocab_size = padded_vocab_size - self.padded_vocab_size = padded_vocab_size - self.hidden_size = hidden_size - self.ffn_hidden_size = ffn_hidden_size - self.kv_channels = kv_channels - self.num_attention_heads = num_attention_heads - self.seq_length = seq_length - self.hidden_dropout = hidden_dropout - self.classifier_dropout = classifier_dropout - self.attention_dropout = attention_dropout - self.layernorm_epsilon = layernorm_epsilon - self.rmsnorm = rmsnorm - self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm - self.post_layer_norm = post_layer_norm - self.add_bias_linear = add_bias_linear - self.add_qkv_bias = add_qkv_bias - self.bias_dropout_fusion = bias_dropout_fusion - self.multi_query_attention = multi_query_attention - self.multi_query_group_num = multi_query_group_num - self.rope_ratio = rope_ratio - self.apply_query_key_layer_scaling = apply_query_key_layer_scaling - self.attention_softmax_in_fp32 = attention_softmax_in_fp32 - self.fp32_residual_connection = fp32_residual_connection - self.pre_seq_len = pre_seq_len - self.prefix_projection = prefix_projection - self.boi_token_id = boi_token_id - self.eoi_token_id = eoi_token_id - super().__init__(**kwargs) \ No newline at end of file diff --git a/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py b/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py deleted file mode 100644 index c59b6a0a1..000000000 --- a/src/llmcompressor/transformers/tracing/glm/modeling_chatglm.py +++ /dev/null @@ -1,1349 +0,0 @@ -# flake8: noqa -# vllm-project: no copyright -""" PyTorch GLM-4V model. """ -import math -import sys -import torch -import torch.utils.checkpoint -import torch.nn.functional as F -from torch import nn -from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss -from torch.nn.utils import skip_init -from typing import Optional, Tuple, Union, List, Dict, Any - -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging, is_torch_npu_available -from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput - -from .visual import EVA2CLIPModel -from .configuration_chatglm import ChatGLMConfig - -# TRACING: import wrap -from torch.fx import wrap - -try: - from transformers.utils import ( - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - ) - - if is_flash_attn_2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import ( # noqa - index_first_axis, - pad_input, - unpad_input, - ) -except: - pass - -# flags required to enable jit fusion kernels - -if sys.platform != 'darwin' and not is_torch_npu_available(): - torch._C._jit_set_profiling_mode(False) - torch._C._jit_set_profiling_executor(False) - torch._C._jit_override_can_fuse_on_cpu(True) - torch._C._jit_override_can_fuse_on_gpu(True) - -logger = logging.get_logger(__name__) - -LANGUAGE_TOKEN_TYPE = 0 -VISION_TOKEN_TYPE = 1 - -_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM" -_CONFIG_FOR_DOC = "ChatGLMConfig" - - -def default_init(cls, *args, **kwargs): - return cls(*args, **kwargs) - - -class InvalidScoreLogitsProcessor(LogitsProcessor): - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - if torch.isnan(scores).any() or torch.isinf(scores).any(): - scores.zero_() - scores[..., 198] = 5e4 - return scores - - -class PrefixEncoder(torch.nn.Module): - """ - The torch.nn model to encode the prefix - Input shape: (batch-size, prefix-length) - Output shape: (batch-size, prefix-length, 2*layers*hidden) - """ - - def __init__(self, config: ChatGLMConfig): - super().__init__() - self.prefix_projection = config.prefix_projection - if self.prefix_projection: - # Use a two-layer MLP to encode the prefix - kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2 - self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) - self.trans = torch.nn.Sequential( - torch.nn.Linear(kv_size, config.hidden_size), - torch.nn.Tanh(), - torch.nn.Linear(config.hidden_size, kv_size) - ) - else: - self.embedding = torch.nn.Embedding(config.pre_seq_len, - config.num_layers * config.kv_channels * config.multi_query_group_num * 2) - - def forward(self, prefix: torch.Tensor): - if self.prefix_projection: - prefix_tokens = self.embedding(prefix) - past_key_values = self.trans(prefix_tokens) - else: - past_key_values = self.embedding(prefix) - return past_key_values - - -def split_tensor_along_last_dim( - tensor: torch.Tensor, - num_partitions: int, - contiguous_split_chunks: bool = False, -) -> List[torch.Tensor]: - """Split a tensor along its last dimension. - - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. - - Returns: - A list of Tensors - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = tensor.size()[last_dim] // num_partitions - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. - if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - -class RotaryEmbedding(nn.Module): - def __init__(self, dim, rope_ratio=1, original_impl=False, device=None, dtype=None): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) - self.register_buffer("inv_freq", inv_freq) - self.dim = dim - self.original_impl = original_impl - self.rope_ratio = rope_ratio - - def impl(self, seq_length: int, dim: int, device: torch.device, dtype: torch.dtype): - base = 10000 * self.rope_ratio - inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)) - seq = torch.arange(seq_length, device=inv_freq.device, dtype=torch.float32) - freqs = torch.outer(seq, inv_freq) - # first part even vector components, second part odd vector components, - # 2 * dim in dimension size - emb = torch.cat((freqs, freqs), dim=-1) - return emb - - def forward_impl( - self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000 - ): - """Enhanced Transformer with Rotary Position Embedding. - - Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ - transformers/rope/__init__.py. MIT License: - https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. - """ - # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ - base = base * self.rope_ratio - theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem)) - - # Create position indexes `[0, 1, ..., seq_len - 1]` - seq_idx = torch.arange(seq_len, dtype=torch.float, device=device) - - # Calculate the product of position index and $\theta_i$ - idx_theta = torch.outer(seq_idx, theta).float() - - cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) - - # this is to mimic the behaviour of complex32, else we will get different results - if dtype in (torch.float16, torch.bfloat16, torch.int8): - cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() - return cache - - def forward(self, max_seq_len, offset=0): - if self.original_impl: - return self.forward_impl( - max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device - ) - else: - return self.impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device) - - -@torch.jit.script -def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: - # x: [b, np, sq, hn] - b, np, sq, hn = x.size(0), x.size(1), x.size(2), x.size(3) - rot_dim = rope_cache.shape[-2] * 2 - x, x_pass = x[..., :rot_dim], x[..., rot_dim:] - # truncate to support variable sizes - rope_cache = rope_cache[:, :sq] - xshaped = x.reshape(b, np, sq, rot_dim // 2, 2) - rope_cache = rope_cache.view(-1, 1, sq, xshaped.size(3), 2) - x_out2 = torch.stack( - [ - xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], - xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], - ], - -1, - ) - x_out2 = x_out2.flatten(3) - return torch.cat((x_out2, x_pass), dim=-1) - - -class RMSNorm(torch.nn.Module): - def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): - super().__init__() - self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) - self.eps = eps - - def forward(self, hidden_states: torch.Tensor): - input_dtype = hidden_states.dtype - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.eps) - - return (self.weight * hidden_states).to(input_dtype) - - - -class CoreAttention(torch.nn.Module): - def __init__(self, config: ChatGLMConfig, layer_number): - super(CoreAttention, self).__init__() - - self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling - self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 - if self.apply_query_key_layer_scaling: - self.attention_softmax_in_fp32 = True - self.layer_number = max(1, layer_number) - - projection_size = config.kv_channels * config.num_attention_heads - - # Per attention head and per partition values. - self.hidden_size_per_partition = projection_size - self.hidden_size_per_attention_head = projection_size // config.num_attention_heads - self.num_attention_heads_per_partition = config.num_attention_heads - - coeff = None - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - if self.apply_query_key_layer_scaling: - coeff = self.layer_number - self.norm_factor *= coeff - self.coeff = coeff - - self.attention_dropout = torch.nn.Dropout(config.attention_dropout) - - def forward(self, query_layer, key_layer, value_layer, attention_mask): - pytorch_major_version = int(torch.__version__.split('.')[0]) - if pytorch_major_version >= 2: - if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - is_causal=True) - else: - if attention_mask is not None: - attention_mask = ~attention_mask - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - attention_mask) - context_layer = context_layer.transpose(1, 2).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.reshape(*new_context_layer_shape) - else: - # Raw attention scores - - # [b, np, sq, sk] - output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2)) - - # [b, np, sq, hn] -> [b * np, sq, hn] - query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1) - # [b, np, sk, hn] -> [b * np, sk, hn] - key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) - - # preallocting input tensor: [b * np, sq, sk] - matmul_input_buffer = torch.empty( - output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, - device=query_layer.device - ) - - # Raw attention scores. [b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_input_buffer, - query_layer, # [b * np, sq, hn] - key_layer.transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor), - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - # =========================== - # Attention probs and dropout - # =========================== - - # attention scores and attention mask [b, np, sq, sk] - if self.attention_softmax_in_fp32: - attention_scores = attention_scores.float() - if self.coeff is not None: - attention_scores = attention_scores * self.coeff - if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: - attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3], - device=attention_scores.device, dtype=torch.bool) - attention_mask.tril_() - attention_mask = ~attention_mask - if attention_mask is not None: - attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = attention_probs.type_as(value_layer) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.attention_dropout(attention_probs) - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. - # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) - # change view [b * np, sk, hn] - value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1) - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer) - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - # [b, np, sq, hn] --> [b, sq, np, hn] - context_layer = context_layer.transpose(1, 2).contiguous() - # [b, sq, np, hn] --> [b, sq, hp] - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.reshape(*new_context_layer_shape) - - return context_layer - -class SdpaAttention(CoreAttention): - def forward(self, query_layer, key_layer, value_layer, attention_mask): - if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - is_causal=True, - dropout_p=self.config.attention_dropout if self.training else 0.0) - else: - if attention_mask is not None: - attention_mask = ~attention_mask - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - attention_mask, - dropout_p=self.config.attention_dropout if self.training else 0.0) - context_layer = context_layer.transpose(1, 2).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.reshape(*new_context_layer_shape) - return context_layer - - -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 -class FlashAttention2(CoreAttention): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward(self, query_states, key_states, value_states, attention_mask): - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - batch_size, query_length = query_states.shape[:2] - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. - causal = self.is_causal and query_length != 1 - dropout = self.config.attention_dropout if self.training else 0.0 - # Contains at least one padding token in the sequence - if attention_mask is not None: - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=None, - causal=causal, - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - else: - attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=None, causal=causal - ) - attn_output = attn_output.reshape(batch_size, query_length, self.hidden_size_per_partition).contiguous() - return attn_output - - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim), - indices_k - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -CORE_ATTENTION_CLASSES = { - "eager": CoreAttention, - "sdpa": SdpaAttention, - "flash_attention_2": FlashAttention2 -} - -class SelfAttention(torch.nn.Module): - """Parallel self-attention layer abstract class. - - Self-attention layer takes input with size [s, b, h] - and returns output of the same size. - """ - - def __init__(self, config: ChatGLMConfig, layer_number, device=None): - super(SelfAttention, self).__init__() - self.layer_number = max(1, layer_number) - - self.projection_size = config.kv_channels * config.num_attention_heads - - # Per attention head and per partition values. - self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads - self.num_attention_heads_per_partition = config.num_attention_heads - - self.multi_query_attention = config.multi_query_attention - self.qkv_hidden_size = 3 * self.projection_size - self.original_rope = config.original_rope - if self.multi_query_attention: - self.num_multi_query_groups_per_partition = config.multi_query_group_num - self.qkv_hidden_size = ( - self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num - ) - self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size, - bias=config.add_bias_linear or config.add_qkv_bias, - device=device, **_config_to_kwargs(config) - ) - - self.core_attention = CoreAttention(config, self.layer_number) - - # Output. - self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, - device=device, **_config_to_kwargs(config) - ) - - def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): - if self.multi_query_attention: - num_attention_heads = self.num_multi_query_groups_per_partition - else: - num_attention_heads = self.num_attention_heads_per_partition - return torch.empty( - inference_max_sequence_len, - batch_size, - num_attention_heads, - self.hidden_size_per_attention_head, - dtype=dtype, - device=device, - ) - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True - ): - # hidden_states: [b, sq, h] - - # ================================================= - # Pre-allocate memory for key-values for inference. - # ================================================= - # ===================== - # Query, Key, and Value - # ===================== - - # Attention heads [b, sq, h] --> [b, sq, (np * 3 * hn)] - mixed_x_layer = self.query_key_value(hidden_states) - - if self.multi_query_attention: - (query_layer, key_layer, value_layer) = mixed_x_layer.split( - [ - self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - ], - dim=-1, - ) - query_layer = query_layer.view( - query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) - key_layer = key_layer.view( - key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) - ) - value_layer = value_layer.view( - value_layer.size()[:-1] - + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) - ) - else: - new_tensor_shape = mixed_x_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - - # [b, sq, np, 3 * hn] --> 3 [b, sq, np, hn] - (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) - - # [b, sq, np, hn] -> [b, np, sq, hn] - query_layer, key_layer, value_layer = [k.transpose(1, 2) for k in [query_layer, key_layer, value_layer]] - - # apply relative positional encoding (rotary embedding) - if rotary_pos_emb is not None: - query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) - key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) - - # adjust key and value for inference - if kv_cache is not None: - cache_k, cache_v = kv_cache - key_layer = torch.cat((cache_k, key_layer), dim=2) - value_layer = torch.cat((cache_v, value_layer), dim=2) - if use_cache: - kv_cache = (key_layer, value_layer) - else: - kv_cache = None - - if self.multi_query_attention: - key_layer = key_layer.unsqueeze(2) - key_layer = key_layer.expand( - -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1 - ) - key_layer = key_layer.contiguous().view( - key_layer.size()[:1] + (self.num_attention_heads_per_partition,) + key_layer.size()[3:] - ) - value_layer = value_layer.unsqueeze(2) - value_layer = value_layer.expand( - -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1 - ) - value_layer = value_layer.contiguous().view( - value_layer.size()[:1] + (self.num_attention_heads_per_partition,) + value_layer.size()[3:] - ) - - # ================================== - # core attention computation - # ================================== - - context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) - - # ================= - # Output. [sq, b, h] - # ================= - - output = self.dense(context_layer) - - return output, kv_cache - - -def _config_to_kwargs(args): - common_kwargs = { - "dtype": args.torch_dtype, - } - return common_kwargs - - -class MLP(torch.nn.Module): - """MLP. - - MLP will take the input with h hidden state, project it to 4*h - hidden dimension, perform nonlinear transformation, and project the - state back into h hidden dimension. - """ - - def __init__(self, config: ChatGLMConfig, device=None): - super(MLP, self).__init__() - - self.add_bias = config.add_bias_linear - - # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf - self.dense_h_to_4h = nn.Linear( - config.hidden_size, - config.ffn_hidden_size * 2, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config) - ) - - def swiglu(x): - x = torch.chunk(x, 2, dim=-1) - return F.silu(x[0]) * x[1] - - self.activation_func = swiglu - - # Project back to h. - self.dense_4h_to_h = nn.Linear( - config.ffn_hidden_size, - config.hidden_size, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config) - ) - - def forward(self, hidden_states): - # [s, b, 4hp] - intermediate_parallel = self.dense_h_to_4h(hidden_states) - intermediate_parallel = self.activation_func(intermediate_parallel) - # [s, b, h] - output = self.dense_4h_to_h(intermediate_parallel) - return output - - -class GLMBlock(torch.nn.Module): - """A single transformer layer. - - Transformer layer takes input with size [s, b, h] and returns an - output of the same size. - """ - - def __init__(self, config: ChatGLMConfig, layer_number, device=None): - super(GLMBlock, self).__init__() - self.layer_number = layer_number - - self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm - - self.fp32_residual_connection = config.fp32_residual_connection - - LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm - # Layernorm on the input data. - self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - # Self attention. - self.self_attention = SelfAttention(config, layer_number, device=device) - self.hidden_dropout = config.hidden_dropout - - # Layernorm on the attention output - self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - # MLP - self.mlp = MLP(config, device=device) - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, - ): - # hidden_states: [s, b, h] - - # Layer norm at the beginning of the transformer layer. - layernorm_output = self.input_layernorm(hidden_states) - # Self attention. - attention_output, kv_cache = self.self_attention( - layernorm_output, - attention_mask, - rotary_pos_emb, - kv_cache=kv_cache, - use_cache=use_cache - ) - - # Residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = hidden_states - - layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) - layernorm_input = residual + layernorm_input - - # Layer norm post the self attention. - layernorm_output = self.post_attention_layernorm(layernorm_input) - - # MLP. - mlp_output = self.mlp(layernorm_output) - - # Second residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = layernorm_input - - output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) - output = residual + output - - return output, kv_cache - - -class GLMTransformer(torch.nn.Module): - """Transformer class.""" - - def __init__(self, config: ChatGLMConfig, device=None): - super(GLMTransformer, self).__init__() - - self.fp32_residual_connection = config.fp32_residual_connection - self.post_layer_norm = config.post_layer_norm - - # Number of layers. - self.num_layers = config.num_layers - - # Transformer layers. - def build_layer(layer_number): - return GLMBlock(config, layer_number, device=device) - - self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) - - if self.post_layer_norm: - LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm - # Final layer norm before output. - self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - self.gradient_checkpointing = False - - def _get_layer(self, layer_number): - return self.layers[layer_number] - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None, - use_cache: Optional[bool] = True, - output_hidden_states: Optional[bool] = False, - ): - if not kv_caches: - kv_caches = [None for _ in range(self.num_layers)] - presents = () if use_cache else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - all_self_attentions = None - all_hidden_states = () if output_hidden_states else None - for index in range(self.num_layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer = self._get_layer(index) - if self.gradient_checkpointing and self.training: - layer_ret = torch.utils.checkpoint.checkpoint( - layer, - hidden_states, - attention_mask, - rotary_pos_emb, - kv_caches[index], - use_cache, - use_reentrant=False - ) - else: - layer_ret = layer( - hidden_states, - attention_mask, - rotary_pos_emb, - kv_cache=kv_caches[index], - use_cache=use_cache - ) - hidden_states, kv_cache = layer_ret - if use_cache: - presents = presents + (kv_cache,) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # Final layer norm. - if self.post_layer_norm: - hidden_states = self.final_layernorm(hidden_states) - - return hidden_states, presents, all_hidden_states, all_self_attentions - - -class ChatGLMPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained models. - """ - - is_parallelizable = False - supports_gradient_checkpointing = True - config_class = ChatGLMConfig - base_model_prefix = "transformer" - _no_split_modules = ["GLMBlock"] - _supports_flash_attn_2 = True - _supports_sdpa = True - - def _init_weights(self, module: nn.Module): - """Initialize the weights.""" - return - - def get_masks(self, input_embeds, past_key_values, padding_mask=None): - batch_size, seq_length, embed_size = input_embeds.shape - full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_embeds.device) - full_attention_mask.tril_() - past_length = 0 - if past_key_values: - past_length = past_key_values[0][0].shape[2] - if past_length: - full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length, - device=input_embeds.device), full_attention_mask), dim=-1) - if padding_mask is not None: - full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) - if not past_length and padding_mask is not None: - full_attention_mask -= padding_mask.unsqueeze(-1) - 1 - full_attention_mask = (full_attention_mask < 0.5).bool() - full_attention_mask.unsqueeze_(1) - return full_attention_mask - - def get_position_ids(self, input_ids, device): - batch_size, seq_length = input_ids.shape - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) - return position_ids - - def get_multimodal_position_ids(self, input_ids, device): - batch_size, seq_length = input_ids.shape - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) - -class Embedding(torch.nn.Module): - """Language model embeddings.""" - - def __init__(self, config: ChatGLMConfig, device=None): - super(Embedding, self).__init__() - - self.hidden_size = config.hidden_size - # Word embeddings (parallel). - self.word_embeddings = nn.Embedding( - config.padded_vocab_size, - self.hidden_size, - dtype=config.torch_dtype, - device=device - ) - self.fp32_residual_connection = config.fp32_residual_connection - - def forward(self, input_ids): - # Embeddings. - words_embeddings = self.word_embeddings(input_ids) - embeddings = words_embeddings - # If the input flag for fp32 residual connection is set, convert for float. - if self.fp32_residual_connection: - embeddings = embeddings.float() - return embeddings - - -# TRACING: this function is untraceable -@wrap -def is_empty(images_list: Optional[List[List[torch.Tensor]]]): - if images_list is None or len(images_list) == 0: - return True - for image_list in images_list: - if image_list is not None: - return False - return True - - -class ChatGLMModel(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): - super().__init__(config) - if empty_init: - init_method = skip_init - else: - init_method = default_init - init_kwargs = {} - if device is not None: - init_kwargs["device"] = device - self.embedding = init_method(Embedding, config, **init_kwargs) - self.num_layers = config.num_layers - self.multi_query_group_num = config.multi_query_group_num - self.kv_channels = config.kv_channels - - # Rotary positional embeddings - self.seq_length = config.seq_length - rotary_dim = ( - config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels - ) - - self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio, - original_impl=config.original_rope, - device=device, dtype=config.torch_dtype) - self.encoder = init_method(GLMTransformer, config, **init_kwargs) - self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False, - dtype=config.torch_dtype, **init_kwargs) - self.pre_seq_len = config.pre_seq_len - self.prefix_projection = config.prefix_projection - if self.pre_seq_len is not None: - for param in self.parameters(): - param.requires_grad = False - self.prefix_tokens = torch.arange(self.pre_seq_len).long() - self.prefix_encoder = PrefixEncoder(config) - self.dropout = torch.nn.Dropout(0.1) - - self.vision = EVA2CLIPModel(config) - - def get_input_embeddings(self): - return self.embedding.word_embeddings - - def set_input_embeddings(self, value): - self.embedding.word_embeddings = value - - def get_prompt(self, batch_size, device, dtype=torch.half): - prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) - past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) - past_key_values = past_key_values.view( - batch_size, - self.pre_seq_len, - self.pre_seq_len, - self.num_layers * 2, - self.multi_query_group_num, - self.kv_channels - ) - # seq_len, b, nh, hidden_size - past_key_values = self.dropout(past_key_values) - past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) - return past_key_values - - def forward( - self, - input_ids: torch.LongTensor = None, - images: torch.Tensor = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.BoolTensor] = None, - full_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - """take care of image_encode, position_ids and (attention_mask = None is fine)""" - - # generate mode with past_key_values. the image features are already mapped - if past_key_values is None: - # not allow for inputs_embeds, because we want to process image feature - assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}" - if not is_empty(images): # multi-modality - image_size: int = self.config.vision_config['image_size'] - patch_size: int = self.config.vision_config['patch_size'] - num_patches = (image_size // patch_size // 2) ** 2 - assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}" - inputs_embeds = self.embedding(input_ids) - - images = images.to(dtype=inputs_embeds.dtype) - images_features = self.vision(images) - - if position_ids is None: - position_ids = self.get_position_ids(input_ids, device=inputs_embeds.device) - new_input_embeds, new_position_ids = [], [] - - for i in range(len(input_ids)): - input_id = input_ids[i].tolist() - boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index( - self.config.eoi_token_id) - # TRACING: Assume that processing and tokenization was done correctly - #assert eoi_token_pos - boi_token_pos == 2 - new_input_embeds.append(torch.cat( - (inputs_embeds[i, :boi_token_pos], images_features[i].to(inputs_embeds.device), - inputs_embeds[i, eoi_token_pos + 1:]))) - new_position_ids.append(torch.cat( - (position_ids[i, :boi_token_pos + 1], position_ids[i, boi_token_pos + 1].repeat(num_patches), - position_ids[i, eoi_token_pos:]) - )) - inputs_embeds = torch.stack(new_input_embeds, dim=0) - position_ids = torch.stack(new_position_ids, dim=0) - - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - batch_size, seq_length = input_ids.shape - - if inputs_embeds is None: - inputs_embeds = self.embedding(input_ids) - - if self.pre_seq_len is not None: - if past_key_values is None: - past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device, - dtype=inputs_embeds.dtype) - if attention_mask is not None: - attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), - attention_mask], dim=-1) - - if full_attention_mask is None: - # TRACING: Assume only prefill and that the attention mask is full - #if (attention_mask is not None and not attention_mask.all().item()) or (past_key_values and seq_length != 1): - if False: #if (attention_mask is not None and not attention_mask.all().item()) or (past_key_values and seq_length != 1): - if self.training: - # https://github.com/THUDM/GLM-4/issues/264 - new_input_ids, new_attention_mask = [], [] - for i in range(len(input_ids)): - input_id = input_ids[i].tolist() - boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index(self.config.eoi_token_id) - # TRACING: Assume that processing and tokenization was done correctly - #assert eoi_token_pos - boi_token_pos == 2 - - new_attention_mask.append(torch.cat( - (attention_mask[i, :boi_token_pos + 1], torch.ones(num_patches).to(attention_mask.device), - attention_mask[i, eoi_token_pos:]))) - - new_input_ids.append(torch.cat( - (input_ids[i, :boi_token_pos + 1], input_ids[i, -1].repeat(num_patches), - input_ids[i, eoi_token_pos:]))) - - attention_mask = torch.stack(new_attention_mask, dim=0) - input_ids = torch.stack(new_input_ids, dim=0) - inputs_embeds = self.embedding(input_ids) - - full_attention_mask = self.get_masks(inputs_embeds, past_key_values, padding_mask=attention_mask) - - # Rotary positional embeddings - rotary_pos_emb = self.rotary_pos_emb(self.seq_length) - - if position_ids is not None: - rotary_pos_emb = rotary_pos_emb[position_ids] - else: - rotary_pos_emb = rotary_pos_emb[None, :seq_length] - - # Run encoder. - hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( - inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, - kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states - ) - - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - -def _history_to_prompt(history, query): - prompt = '' - flag = False - for i, (old_query, response) in enumerate(history): - prompt += ('<|user|>' if flag else '') + old_query + "<|assistant|>" + response + "<|endoftext|>" - flag = True - prompt += '{}{}<|assistant|>'.format('<|user|>' if flag else '', query) - return prompt - - -class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): - super().__init__(config) - - self.max_sequence_length = config.max_length - self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) - self.config = config - - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - ) -> Dict[str, Any]: - # update past_key_values - cache_name, cache = self._extract_past_from_model_output(outputs) - model_kwargs[cache_name] = cache - - # update attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 - ) - - # update position ids - if "position_ids" in model_kwargs: - position_ids = model_kwargs["position_ids"] - new_position_id = position_ids[..., -1:].clone() - new_position_id += 1 - model_kwargs["position_ids"] = torch.cat( - [position_ids, new_position_id], dim=-1 - ) - - model_kwargs["is_first_forward"] = False - return model_kwargs - - def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - images: Optional[torch.Tensor] = None, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - is_first_forward: bool = True, - **kwargs - ) -> dict: - # only last token for input_ids if past is not None - if position_ids is None: - position_ids = self.get_position_ids(input_ids, device=input_ids.device) - if attention_mask is not None: - image_size: int = self.config.vision_config['image_size'] - patch_size: int = self.config.vision_config['patch_size'] - num_patches = (image_size // patch_size // 2) ** 2 - new_attention_masks = [] - - # if not image, use this default id - eoi_token_pos = 6 - boi_token_pos = 4 - - for i in range(len(input_ids)): - input_id = input_ids[i].tolist() - if not is_empty(images): - boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index( - self.config.eoi_token_id) - # TRACING: Assume that processing and tokenization was done correctly - #assert eoi_token_pos - boi_token_pos == 2 - new_attention_masks.append(torch.cat( - (attention_mask[i, :boi_token_pos + 1], attention_mask.new_ones(num_patches), - attention_mask[i, eoi_token_pos:]) - )) - attention_mask = torch.stack(new_attention_masks, dim=0) - if not is_first_forward: - if past_key_values is not None: - position_ids = position_ids[..., -1:] - input_ids = input_ids[:, -1:] - return { - "input_ids": input_ids, - "images": images, - "past_key_values": past_key_values, - "position_ids": position_ids, - "attention_mask": attention_mask, - "return_last_logit": True, - "use_cache": use_cache - } - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - images: List[List[torch.Tensor]] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - return_last_logit: Optional[bool] = False, - ): - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids=input_ids, - images=images, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - if return_last_logit: - hidden_states = hidden_states[:, -1:] - lm_logits = self.transformer.output_layer(hidden_states) - - loss = None - if labels is not None: - new_labels = [] - for i in range(len(input_ids)): - input_id = input_ids[i].tolist() - boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index( - self.config.eoi_token_id) - # TRACING: Assume that processing and tokenization was done correctly - #assert eoi_token_pos - boi_token_pos == 2 - - new_labels.append(torch.cat( - ( - labels[i, :boi_token_pos + 1], - torch.tensor([-100]).to(labels.device).to(labels.dtype).repeat(1600), - labels[i, eoi_token_pos:]))) - - labels = torch.stack(new_labels, dim=0) - lm_logits = lm_logits.to(torch.float32) - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - loss_fct = CrossEntropyLoss(ignore_index=-100) - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - lm_logits = lm_logits.to(hidden_states.dtype) - loss = loss.to(hidden_states.dtype) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def _reorder_cache( - past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - - Output shares the same memory storage as `past`. - """ - return tuple( - ( - layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)), - layer_past[1].index_select(0, beam_idx.to(layer_past[1].device)), - ) - for layer_past in past - ) - -class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): - super().__init__(config) - - self.num_labels = config.num_labels - self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) - - self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half) - if config.classifier_dropout is not None: - self.dropout = nn.Dropout(config.classifier_dropout) - else: - self.dropout = None - self.config = config - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - full_attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: Optional[torch.LongTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - full_attention_mask=full_attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - pooled_hidden_states = hidden_states[-1] - if self.dropout is not None: - pooled_hidden_states = self.dropout(pooled_hidden_states) - logits = self.classifier_head(pooled_hidden_states) - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(logits.squeeze().float(), labels.squeeze()) - else: - loss = loss_fct(logits.float(), labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits.float(), labels.view(-1, self.num_labels)) - - if not return_dict: - output = (logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) \ No newline at end of file diff --git a/src/llmcompressor/transformers/tracing/glm/visual.py b/src/llmcompressor/transformers/tracing/glm/visual.py deleted file mode 100644 index 1a7792e57..000000000 --- a/src/llmcompressor/transformers/tracing/glm/visual.py +++ /dev/null @@ -1,182 +0,0 @@ -# flake8: noqa -# vllm-project: no copyright -import torch -from torch import nn -from argparse import Namespace -import torch.nn.functional as F -from transformers.activations import ACT2FN -import math -from torch.nn import LayerNorm - - -def standard_attention(query_layer, key_layer, value_layer, scaling_attention_score=True): - if scaling_attention_score: - query_layer = query_layer / math.sqrt(query_layer.shape[-1]) - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - - attention_probs = F.softmax(attention_scores, dim=-1) - - context_layer = torch.matmul(attention_probs, value_layer) - return context_layer - - -def attention_fn_default(query_layer, key_layer, value_layer, scaling_attention_score=True): - if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score: - # Pytorch 2.0 attention uses very much memory if attention_mask is float, and has NaN bug if attention_mask is None. - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_layer, key_layer, value_layer, - attn_mask=None, - dropout_p=0., - is_causal=False - ) - return attn_output - else: - return standard_attention( - query_layer, key_layer, value_layer, scaling_attention_score=scaling_attention_score - ) - - -class PatchEmbedding(nn.Module): - def __init__(self, config): - super().__init__() - self.proj = nn.Conv2d(config.in_channels, config.hidden_size, kernel_size=config.patch_size, - stride=config.patch_size) - self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size)) - self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size) - - def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)": - x = self.proj(images) - x = x.flatten(2).transpose(1, 2) - cls_token = self.cls_embedding.expand(x.shape[0], -1, -1) - x = torch.cat((cls_token, x), dim=1) - x += self.position_embedding.weight.unsqueeze(0) - return x - - -class Attention(nn.Module): - def __init__(self, config): - super().__init__() - self.num_heads = config.num_heads - head_dim = config.hidden_size // config.num_heads - self.scale = head_dim ** -0.5 - self.query_key_value = nn.Linear(config.hidden_size, config.hidden_size * 3) - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.output_dropout = torch.nn.Dropout(config.dropout_prob) - - def forward(self, x: "tensor(B, L, D)") -> "tensor(B, L, D)": - B, L, _ = x.shape - qkv = self.query_key_value(x) - qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, H, L, D - q, k, v = qkv[0], qkv[1], qkv[2] - - out = attention_fn_default( - q, k, v - ) - output = self.dense(out.transpose(1, 2).reshape(B, L, -1)) - output = self.output_dropout(output) - return output - - def attention(self, q, k, v): - attn_weights = torch.matmul(q * self.scale, k.transpose(-2, -1)) - attn_weights = attn_weights.softmax(dim=-1) - output = torch.matmul(attn_weights, v) - return output - - -class MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.activation_fn = ACT2FN[config.hidden_act] - self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) - self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.fc1(x) - x = self.activation_fn(x) - x = self.fc2(x) - return x - - -class TransformerLayer(nn.Module): - def __init__(self, config): - super().__init__() - self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attention = Attention(config) - self.mlp = MLP(config) - self.post_attention_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - def forward(self, hidden_states): - attention_input = hidden_states - attention_output = self.input_layernorm(self.attention(attention_input)) - hidden_states = attention_input + attention_output - mlp_input = hidden_states - - # https://github.com/THUDM/GLM-4/issues/350 - mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)).to(mlp_input.device) - output = mlp_input + mlp_output - return output - - -class Transformer(nn.Module): - def __init__(self, config): - super().__init__() - self.layers = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_hidden_layers)]) - - def forward(self, hidden_states): - for layer_module in self.layers: - hidden_states = layer_module(hidden_states) - return hidden_states - - -class GLU(nn.Module): - def __init__(self, config, in_features): - super().__init__() - self.linear_proj = nn.Linear(in_features, config.hidden_size, bias=False) - self.norm1 = nn.LayerNorm(config.hidden_size) - self.act1 = nn.GELU() - self.act2 = nn.functional.silu - self.dense_h_to_4h = nn.Linear(config.hidden_size, config.ffn_hidden_size, bias=False) - self.gate_proj = nn.Linear(config.hidden_size, config.ffn_hidden_size, bias=False) - self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, config.hidden_size, bias=False) - - def forward(self, x): - x = self.linear_proj(x) - x = self.act1(self.norm1(x)) - x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x) - x = self.dense_4h_to_h(x) - return x - - -class EVA2CLIPModel(nn.Module): - def __init__(self, config): - super().__init__() - vision_config = Namespace(**config.vision_config) - self.patch_embedding = PatchEmbedding(vision_config) - self.transformer = Transformer(vision_config) - self.linear_proj = GLU(config, in_features=config.hidden_size) - self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, out_channels=config.hidden_size, kernel_size=2, - stride=2) - self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) - self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) - self.scaling_factor = vision_config.scaling_factor - - def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)": - x = self.patch_embedding(images) - x = self.transformer(x) - x = x[:, 1:] - - b, s, h = x.shape - grid_size = int(s ** 0.5) - x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2) - x = self.conv(x) - - x = x.flatten(2).transpose(1, 2) - x = self.linear_proj(x) - - # https://github.com/THUDM/GLM-4/issues/350 - boi = self.boi.expand(x.shape[0], -1, -1).to(x.device) - eoi = self.eoi.expand(x.shape[0], -1, -1).to(x.device) - x = torch.cat((boi, x, eoi), dim=1) - x = x / self.scaling_factor - return x \ No newline at end of file diff --git a/src/llmcompressor/transformers/utils/data_collator.py b/src/llmcompressor/transformers/utils/data_collator.py index 930b06696..b2dc7c651 100644 --- a/src/llmcompressor/transformers/utils/data_collator.py +++ b/src/llmcompressor/transformers/utils/data_collator.py @@ -5,7 +5,6 @@ "pixtral_data_collator", "llava_data_collator", "qwen2_vl_data_collator", - "glm_data_collator", ] @@ -47,13 +46,3 @@ def qwen2_vl_data_collator(batch): "pixel_values": torch.tensor(batch[0]["pixel_values"]), "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]), } - - -def glm_data_collator(batch): - assert len(batch) == 1 - return { - "input_ids": torch.LongTensor(batch[0]["input_ids"]), - "attention_mask": torch.tensor(batch[0]["attention_mask"]), - "position_ids": torch.tensor(batch[0]["position_ids"]), - "images": torch.tensor(batch[0]["images"]), - } From f6312d05c81a10e1f0b4d4ed3ce5d885496f8772 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 7 Jan 2025 00:13:49 +0000 Subject: [PATCH 283/285] docstrings, reorder pipeline args Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/quantization/gptq/base.py | 4 ++-- src/llmcompressor/pipelines/basic/pipeline.py | 3 +++ src/llmcompressor/pipelines/layer_sequential/helpers.py | 4 ++-- src/llmcompressor/pipelines/layer_sequential/pipeline.py | 6 +++++- src/llmcompressor/pipelines/sequential/pipeline.py | 7 ++++++- 5 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 136bfdb9a..178fbce4c 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -237,9 +237,9 @@ def on_initialize(self, state: "State", **kwargs) -> bool: try: run_sequential( state.model, + state.data.calib, self.sequential_targets, self.ignore, - state.data.calib, ) return True @@ -253,8 +253,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool: try: run_layer_sequential( state.model, - self.sequential_targets, state.data.calib, + self.sequential_targets, ) return True diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index 0ac0afbc2..c7552a654 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -18,6 +18,9 @@ def run_pipeline(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader through the model. This pipeline is typically used for basic model calibration and, unlike the sequential pipelines, does not propagate compression error when used to calibrate model compression + + :param model: model being calibrated + :param dataloader: loads data for calibration """ model_device = get_execution_device(model) diff --git a/src/llmcompressor/pipelines/layer_sequential/helpers.py b/src/llmcompressor/pipelines/layer_sequential/helpers.py index 762516a1c..91b300cfd 100644 --- a/src/llmcompressor/pipelines/layer_sequential/helpers.py +++ b/src/llmcompressor/pipelines/layer_sequential/helpers.py @@ -119,8 +119,8 @@ def trigger_early_stop_fn(module, args, kwargs): @dataclass class EarlyStopException(Exception): """ - Note: this is exception different from the exception defined in - llmcompressor.modifiers.utils.pytorch_helpers, and will eventually replace + Note: this exception is different from the exception defined in + llmcompressor.modifiers.utils.pytorch_helpers, and will eventually replace it Attribute names `args` and `kwargs` are reserved for `dataclass` """ diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index a1d38e6f0..f93fd6f2d 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -18,8 +18,8 @@ def run_pipeline( model: torch.nn.Module, - sequential_targets: List[str], dataloader: torch.utils.data.DataLoader, + sequential_targets: List[str], ): """ Run a layer-wise sequential data pipeline according to the following steps: @@ -37,6 +37,10 @@ def run_pipeline( If your model architecture violates these assumptions, consider using the sequential pipeline (see llmcompressor.pipelines.sequential). Architectures which are known to fail these assumptions include GPT-J and most vision language models + + :param model: model being calibrated + :param dataloader: loads data for calibration + :param sequential_targets: patterns which match to the layer modules of the model """ # find layers layers = match_modules(model, sequential_targets) diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index fcf1d88b0..647d5761e 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -15,9 +15,9 @@ def run_pipeline( model: torch.nn.Module, + dataloader: torch.utils.data.DataLoader, sequential_targets: List[str], ignore: List[str], - dataloader: torch.utils.data.DataLoader, ): """ Run a sequential data pipeline according to the following steps: @@ -36,6 +36,11 @@ def run_pipeline( In the event that tracing fails, a torch.fx.proxy.TraceError will be raised. A model can be made traceable by wrapping the untraceable functions (see llmcompressor.transformers.tracing) + + :param model: model being calibrated + :param dataloader: loads data for calibration + :param sequential_targets: patterns which match to the layer modules of the model + :param ignore: patterns which match to modules which should be ignored by tracing """ # trace subgraphs sample_input = next(iter(dataloader)) From 153a4fa3c4d4831e01219b8d7a901366555bd960 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 7 Jan 2025 00:22:19 +0000 Subject: [PATCH 284/285] correct typos Signed-off-by: Kyle Sayers --- src/llmcompressor/pipelines/sequential/helpers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index 2364e8042..cc2fb0345 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -58,8 +58,8 @@ def trace_subgraphs( :param model: model being traced :param sample_input: inputs whose values will change during execution but whose __len__, __bool__, and __contains__ values are assumed constant across batches - :param sequential_targets: list of patterns specifying sequential targets - :param ignore: list of patterns specifying modules to ignore during tracing + :param sequential_targets: list of patterns matching sequential targets + :param ignore: list of patterns matching modules to ignore during tracing :return: a list of Subgraphs in order of execution """ # find modules @@ -191,8 +191,8 @@ def topological_partition(graph: GraphModule, targets: Set[Module]) -> List[List :param graph: graph being partitioned :param targets: target modules which will be assigned to disjoint partitions - :return: list of partitions, where each partition is a list of nodes belong to that - partition + :return: list of partitions, where each partition is a list of nodes belonging to + that partition """ assert check_assumption(graph.graph) target_nodes = find_target_nodes(graph, targets) @@ -256,7 +256,7 @@ def partition_graph(model: Module, partitions: List[List[Node]]) -> List[Subgrap :param model: model which owns the produced Subgraphs :param partitions: list of partitions, where each partition is a list of nodes - belong to that partition + belonging to that partition :return: list of subgraphs in order of execution """ subgraphs = [] @@ -352,7 +352,7 @@ def check_assumption(graph: Graph) -> bool: def match_modules(model: Module, target_names: List[str]) -> Set[Module]: """ - Find modules whose names matach the patterns given by `target_names` + Find modules whose names match the patterns given by `target_names` :param model: model containing submodules to find :param target_names: target patterns to find From 3f9dd7de2cebaed28657792139535fbafff996be Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 7 Jan 2025 18:06:50 +0000 Subject: [PATCH 285/285] code clarity Signed-off-by: Kyle Sayers --- .../pipelines/sequential/helpers.py | 12 ++++++------ .../transformers/tracing/mistral.py | 19 +++++++++---------- .../transformers/tracing/mllama.py | 3 +-- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index cc2fb0345..4945ba01e 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -194,7 +194,7 @@ def topological_partition(graph: GraphModule, targets: Set[Module]) -> List[List :return: list of partitions, where each partition is a list of nodes belonging to that partition """ - assert check_assumption(graph.graph) + assert graph_is_well_formed(graph.graph) target_nodes = find_target_nodes(graph, targets) partitions: List[List[Node]] = [[]] @@ -301,7 +301,7 @@ def partition_graph(model: Module, partitions: List[List[Node]]) -> List[Subgrap ) ) - assert check_assumption(graph) + assert graph_is_well_formed(graph) return subgraphs @@ -325,13 +325,13 @@ def trace_consumed_names(subgraphs: List[Subgraph]): raise ValueError(f"Could not find input name {input_name} in subgraphs") -def check_assumption(graph: Graph) -> bool: +def graph_is_well_formed(graph: Graph) -> bool: """ - Checks that a graph is not malformed + A graph is well formed if and only if + `nodeA in NodeB.users <=> nodeB in Node.A.all_input_nodes` :param graph: graph being checked - :return: True if node.users and node.all_input_nodes have bidirectional - relationships, False otherwise + :return: True if the graph is well formed, False otherwise """ for node in graph.nodes: for user in node.users: diff --git a/src/llmcompressor/transformers/tracing/mistral.py b/src/llmcompressor/transformers/tracing/mistral.py index e6d96d912..3c9102b23 100644 --- a/src/llmcompressor/transformers/tracing/mistral.py +++ b/src/llmcompressor/transformers/tracing/mistral.py @@ -32,7 +32,6 @@ ) # TRACING: imports -from torch.fx import wrap from transformers.models.mistral.modeling_mistral import ( MistralPreTrainedModel, MistralModel, @@ -46,7 +45,7 @@ # TRACING: This function is untracable -@wrap +@torch.fx.wrap def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -195,11 +194,11 @@ def _update_causal_mask( return causal_mask -# TRACING: Must use MistralModel +# TRACING: Must use MistralModel with wrapped function class MistralForCausalLM(MistralForCausalLM): def __init__(self, config): super(MistralPreTrainedModel, self).__init__(config) - # TRACING: Must use MistralModel + # TRACING: Must use MistralModel with wrapped function self.model = MistralModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) @@ -208,12 +207,12 @@ def __init__(self, config): self.post_init() -# TRACING: Must use MistralModel +# TRACING: Must use MistralModel with wrapped function class MistralForSequenceClassification(MistralForSequenceClassification): def __init__(self, config): super(MistralPreTrainedModel, self).__init__(config) self.num_labels = config.num_labels - # TRACING: Must use MistralModel + # TRACING: Must use MistralModel with wrapped function self.model = MistralModel(config) self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) @@ -221,12 +220,12 @@ def __init__(self, config): self.post_init() -# TRACING: Must use MistralModel +# TRACING: Must use MistralModel with wrapped function class MistralForTokenClassification(MistralForTokenClassification): def __init__(self, config): super(MistralPreTrainedModel, self).__init__(config) self.num_labels = config.num_labels - # TRACING: Must use MistralModel + # TRACING: Must use MistralModel with wrapped function self.model = MistralModel(config) if getattr(config, "classifier_dropout", None) is not None: classifier_dropout = config.classifier_dropout @@ -240,11 +239,11 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() -# TRACING: Must use MistralModel +# TRACING: Must use MistralModel with wrapped function class MistralForQuestionAnswering(MistralForQuestionAnswering): def __init__(self, config): super(MistralPreTrainedModel, self).__init__(config) - # TRACING: Must use MistralModel + # TRACING: Must use MistralModel with wrapped function self.model = MistralModel(config) self.qa_outputs = nn.Linear(config.hidden_size, 2) diff --git a/src/llmcompressor/transformers/tracing/mllama.py b/src/llmcompressor/transformers/tracing/mllama.py index 395567afd..8b65b179c 100644 --- a/src/llmcompressor/transformers/tracing/mllama.py +++ b/src/llmcompressor/transformers/tracing/mllama.py @@ -28,7 +28,6 @@ ) # TRACING: imports -from torch.fx import wrap from transformers.models.mllama.modeling_mllama import ( MLLAMA_START_DOCSTRING, MllamaForConditionalGeneration, @@ -38,7 +37,7 @@ # TRACING: This function is not traceable -@wrap +@torch.fx.wrap def _prepare_cross_attention_mask( cross_attention_mask: torch.Tensor, num_vision_tokens: int,