diff --git a/src/sparseml/utils/pytorch/__init__.py b/src/sparseml/utils/pytorch/__init__.py
index 10c86104af1..05bd7af1510 100644
--- a/src/sparseml/utils/pytorch/__init__.py
+++ b/src/sparseml/utils/pytorch/__init__.py
@@ -14,4 +14,5 @@
 
 # flake8: noqa
 
+from .converters import *
 from .module import *
diff --git a/src/sparseml/utils/pytorch/converters/__init__.py b/src/sparseml/utils/pytorch/converters/__init__.py
new file mode 100644
index 00000000000..87c7a2fed59
--- /dev/null
+++ b/src/sparseml/utils/pytorch/converters/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# flake8: noqa
+
+
+from .converters import *
diff --git a/src/sparseml/utils/pytorch/converters/converters.py b/src/sparseml/utils/pytorch/converters/converters.py
new file mode 100644
index 00000000000..283136a3bb9
--- /dev/null
+++ b/src/sparseml/utils/pytorch/converters/converters.py
@@ -0,0 +1,170 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import logging
+import shutil
+from abc import ABC
+from pathlib import Path
+from typing import Callable, Dict, Iterable, Optional, Union
+
+import torch
+
+from safetensors.torch import save_file
+from sparseml.pytorch.model_load.helpers import load_safetensors_state_dict
+from sparseml.utils.pytorch.converters.transformations import (
+    transform_autogptq_weights_and_reshape_tensors,
+    transform_exllama_names,
+)
+
+
+StateDictType = Union[Dict[str, torch.Tensor], str, Path]
+TransformationType = Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+class BaseConverter(ABC):
+    @classmethod
+    def translate(cls, state_dict: StateDictType, **kwargs) -> StateDictType:
+        """
+        Applies transformations to the state_dict
+
+        :param state_dict: The state_dict to apply transformations to
+        :param kwargs: Additional arguments to pass to the transformations
+        :return: The transformed state_dict
+        """
+        _LOGGER.info("Applying transformations...")
+        new_state_dict = copy.copy(state_dict)
+        for transformation in cls.transformations():
+            new_state_dict = transformation(new_state_dict, **kwargs)
+        return new_state_dict
+
+    @classmethod
+    def convert_from_safetensors(
+        cls, filepath: str, save_dir: Optional[str] = None
+    ) -> str:
+        """
+        Convert a .safetensors file or directory of .safetensors files, applying
+        transformations to the state_dict and saving the new state_dict to a new
+        directory
+
+        :param filepath: The file path to the .safetensors file or directory
+            containing .safetensors files to convert
+        :param save_dir: The directory to save the converted state_dict to
+        :return: The directory where the converted state_dict was saved
+        """
+        _validate_safetensors_file_path(filepath)
+
+        filepath_: Path = Path(filepath)
+        if not save_dir:
+            save_dir = "compressed_tensors_model"
+
+        save_dir_: Path = Path(save_dir)
+        save_dir_.mkdir(exist_ok=True, parents=True)
+
+        metadata = {"format": "pt", "source": "Created by SparseML"}
+
+        # transform and save the state_dict
+        if filepath_.is_dir():
+            for file in filepath_.glob("*.safetensors"):
+                _LOGGER.info(f"Loading file: {file}")
+                state_dict: StateDictType = load_safetensors_state_dict(file)
+                new_state_dict = cls.translate(state_dict=state_dict)
+                save_file(
+                    new_state_dict, filename=save_dir_ / file.name, metadata=metadata
+                )
+            _copy_non_safetensor_files_(filepath_, save_dir_)
+            _update_quantization_config(filepath_, save_dir_)
+
+        elif filepath_.is_file():
+            state_dict: StateDictType = load_safetensors_state_dict(filepath)
+            new_state_dict = cls.translate(state_dict=state_dict)
+            save_file(
+                new_state_dict, filename=save_dir_ / filepath_.name, metadata=metadata
+            )
+
+        return str(save_dir_)
+
+    @classmethod
+    def transformations(cls) -> Iterable[TransformationType]:
+        """
+        Returns an iterable of transformations that are applied in the converter,
+        each transformation should be a callable that takes a state_dict and returns
+        a transformed state_dict
+        """
+        raise NotImplementedError()
+
+
+class ExllamaToCompressedTensorConverter(BaseConverter):
+    """
+    A converter that applies transformations to the state_dict of an AutoGPTQ
+    quantized model to convert it to a compressed tensors model, which can be
+    loaded by the SparseAutoModel classes
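+
+    Example usage (the checkpoint paths below are hypothetical):
+
+        save_dir = ExllamaToCompressedTensorConverter.convert_from_safetensors(
+            "path/to/exllama_model", save_dir="path/to/converted_model"
+        )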
+    """
+
+    @classmethod
+    def transformations(cls):
+        return (transform_autogptq_weights_and_reshape_tensors, transform_exllama_names)
+
+
+def _validate_safetensors_file_path(filepath: str):
+    """
+    Given a file path, it is valid if:
+        - The file exists
+        - The file is either a single .safetensors file or a
+            directory containing .safetensors files
+
+    :param filepath: A string file path to validate
+    """
+
+    filepath_: Path = Path(filepath)
+
+    if not filepath_.exists():
+        raise FileNotFoundError(f"File not found: {filepath}")
+
+    if filepath_.is_dir() and not any(filepath_.glob("*.safetensors")):
+        raise FileNotFoundError(f"No .safetensors files found in directory: {filepath}")
+
+    if filepath_.is_file() and not filepath_.suffix == ".safetensors":
+        raise ValueError(f"File must be a .safetensors file: {filepath}")
+
+
+def _copy_non_safetensor_files_(source_dir: Path, dest_dir: Path):
+    """
+    A helper function to copy all auxiliary files in a directory that are
+    not .safetensors files, for example config.json, recipe.yaml, ...
+
+    :param source_dir: The directory to copy files from
+    :param dest_dir: The directory to copy files to
+    """
+    for file in source_dir.glob("*"):
+        if file.suffix != ".safetensors":
+            _LOGGER.info(f"Copying file: {file} to {dest_dir}")
+            shutil.copy(file, dest_dir / file.name)
+
+
+def _update_quantization_config(source_dir: Path, dest_dir: Path):
+    """
+    Updates the config.json file in the destination directory by removing
+    the quantization_config attribute
+
+    :param source_dir: The directory containing the original config.json file
+    :param dest_dir: The directory to save the updated config.json file
+    """
+    from sparseml.transformers import SparseAutoConfig
+
+    config = SparseAutoConfig.from_pretrained(source_dir)
+
+    if hasattr(config, "quantization_config"):
+        _LOGGER.info("Updating quantization config...")
+        delattr(config, "quantization_config")
+    config.save_pretrained(dest_dir)
diff --git a/src/sparseml/utils/pytorch/converters/transformations.py b/src/sparseml/utils/pytorch/converters/transformations.py
new file mode 100644
index 00000000000..9a96a847b87
--- /dev/null
+++ b/src/sparseml/utils/pytorch/converters/transformations.py
@@ -0,0 +1,224 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# flake8: noqa: F821
+
+import functools
+import logging
+from typing import Dict
+
+import numpy as np
+import torch
+from torch import Tensor
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def _log_transformation(func):
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        _LOGGER.info("Applying transformation: %s", func.__name__.upper())
+        return_value = func(*args, **kwargs)
+        _LOGGER.info("Transformation: %s complete", func.__name__.upper())
+        return return_value
+
+    return wrapper
+
+
+def is_gptq_quantization_target(key: str) -> bool:
+    """
+    Assumes self_attn and mlp are the only quantization targets
+    in model layers of the state_dict.
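+
+    For example (illustrative key names):
+
+        is_gptq_quantization_target("model.layers.0.self_attn.q_proj.qweight")
+        # -> True
+        is_gptq_quantization_target("model.embed_tokens.weight")
+        # -> False
+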
+    :param key: The key of the state_dict
+    :return: True if the key is a quantization target, False otherwise
+    """
+    return "model.layers" in key and ("self_attn" in key or "mlp" in key)
+
+
+@_log_transformation
+def transform_exllama_names(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    """
+    Transforms the Exllama state_dict keys to be compatible with
+    SparseAutoModel classes.
+
+    The renames include:
+        - scales -> weight_fake_quant.scale
+        - qzeros -> weight_fake_quant.zero_point
+        - qweight -> weight
+
+    Note: does not transform the actual tensor values
+
+    :pre-condition: The state_dict should be for a quantized model
+    :pre-condition: Targets only the weights of the self_attn and mlp nodes
+    :param state_dict: The quantized state_dict to be transformed
+    :return: The transformed state_dict
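+
+    For example, the key "model.layers.0.self_attn.q_proj.qweight" is renamed
+    to "model.layers.0.self_attn.q_proj.weight".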
+    """
+
+    name_map: Dict[str, str] = {
+        ".scales": ".weight_fake_quant.scale",
+        ".qzeros": ".weight_fake_quant.zero_point",
+        ".qweight": ".weight",
+    }
+
+    updated_state_dict = {}
+    for key, tensor in state_dict.items():
+        for suffix, replacement in name_map.items():
+            if key.endswith(suffix):
+                # rename only the matching suffix; the module path is unchanged
+                key = key[: -len(suffix)] + replacement
+                break
+        updated_state_dict[key] = tensor
+    return updated_state_dict
+
+
+@_log_transformation
+def transform_autogptq_weights_and_reshape_tensors(
+    state_dict: Dict[str, Tensor]
+) -> Dict[str, Tensor]:
+    """
+    Transforms weights into the shapes and types required for the Exllama
+    to CompressedTensors conversion
+
+    The transformations include:
+        - Unpack and dequantize the weight tensor using the scales, zeros,
+          and g_idx tensors
+        - Squeeze the scales tensor to [x] from [1, x]
+
+    :pre-condition: The state_dict should be for a quantized model
+    :pre-condition: The state_dict should have the bias and g_idx tensors added
+
+    :param state_dict: The state_dict to be transformed
+    :return: The transformed state_dict, with repacked and reshaped tensors
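+
+    For example, a packed ".qweight" int32 tensor of shape [x, y] is replaced
+    by a dequantized floating point tensor of shape [y, 8x], and its ".scales"
+    tensor is squeezed from [1, x] to [x].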
+    """
+
+    transformed_state_dict: Dict[str, Tensor] = {}
+
+    # auxiliary dict to store transformed weights
+    transformed_weights_dict: Dict[str, Tensor] = {}
+
+    # dequantize qweights before transforming scales and qzeros,
+    # because the ordering in which tensors are fetched
+    # is not guaranteed by our implementation
+    for key, tensor in state_dict.items():
+        if is_gptq_quantization_target(key) and key.endswith(".qweight"):
+            # unpack and dequantize the weight tensor
+            scales = state_dict[key.replace("qweight", "scales")]
+            qzeros = state_dict[key.replace("qweight", "qzeros")]
+            g_idx = state_dict[key.replace("qweight", "g_idx")]
+
+            zeros = unpack_zeros(qzeros)
+            qweight = unpack_int32_into_fp32(
+                qweight=tensor,
+                scales=scales,
+                zeros=zeros,
+                g_idx=g_idx,
+            )
+            transformed_weights_dict[key] = qweight
+
+    # transform scales
+    for key, tensor in state_dict.items():
+        if is_gptq_quantization_target(key) and key.endswith(".scales"):
+            # scales [1, x] should be reshaped to [x]
+            scales = tensor.squeeze(0)
+            transformed_state_dict[key] = scales
+        else:
+            transformed_state_dict[key] = tensor
+
+    # overwrite old weights with the new quantized weights
+    transformed_state_dict.update(transformed_weights_dict)
+
+    # auxiliary weights_dict not needed anymore
+    del transformed_weights_dict
+
+    return transformed_state_dict
+
+
+def unpack_zeros(qzeros: Tensor) -> Tensor:
+    """
+    Unpack each 32 bit integer in the quantized zero points tensor into
+    eight 4 bit zero points.
+
+    :param qzeros: The packed zero points tensor of int32 dtype and shape [1, y]
+    :return: The unpacked zero points tensor of int32 dtype and shape [1, 8y]
+    """
+    bits = 4
+    qzeros = qzeros.numpy().astype(np.uint32)
+    intzeros = np.zeros(
+        (qzeros.shape[0], qzeros.shape[1] * 32 // bits), dtype=np.uint32
+    )
+
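+    # each packed 32-bit word expands into eight 4-bit zero points, so the
+    # unpacked tensor is 8 times wider than the packed one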
+    i = 0
+    col = 0
+    while col < intzeros.shape[1]:
+        if bits in [4]:
+            for j in range(i, min(i + (32 // bits), intzeros.shape[1])):
+                intzeros[:, j] = (qzeros[:, col] >> (bits * (j - i))) & 0xF
+            i += 32 // bits
+            col += 1
+        else:
+            raise NotImplementedError("Only 4 bits are supported.")
+
+    intzeros = intzeros.astype(np.int32)
+    intzeros = torch.from_numpy(intzeros)
+
+    return intzeros
+
+
+def unpack_int32_into_fp32(
+    qweight: Tensor, scales: Tensor, zeros: Tensor, g_idx: Tensor
+) -> Tensor:
+    """
+    Unpack the quantized weight tensor from 32 bit integers into 4 bit integers,
+    and then dequantize them using the scales, zeros, and g_idx tensors.
+
+    :param qweight: The quantized weight tensor of int32 dtype and shape [x, y]
+    :param scales: The scales tensor
+    :param zeros: The zero points tensor
+    :param g_idx: The group index tensor
+    :return: The dequantized weight tensor of shape [y, 8x]
+    """
+    bits = 4
+    qweight = qweight.numpy().astype(np.uint32)
+    intweight = np.zeros(
+        (qweight.shape[0] * 32 // bits, qweight.shape[1]), dtype=np.uint32
+    )
+
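+    # walk the packed rows: shift and mask each 32-bit word to recover the
+    # eight 4-bit values it packs, writing one unpacked row per nibble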
+    i = 0
+    row = 0
+    while row < intweight.shape[0]:
+        if bits in [4]:
+            for j in range(i, min(i + (32 // bits), intweight.shape[0])):
+                intweight[j] = (qweight[row] >> (bits * (j - i))) & 0xF
+            i += 32 // bits
+            row += 1
+        else:
+            raise NotImplementedError("Only 4 bits are supported.")
+
+    intweight = torch.from_numpy(intweight.astype(np.int32))
+    intweight = intweight.t().contiguous()
+
+    scales = scales.t().contiguous()
+    zeros = zeros.t().contiguous()
+    scale_zeros = zeros * scales
+    scales = scales.clone().half()
+
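+    # dequantize column by column: w = (q - zero[g]) * scale[g], where g_idx
+    # maps each input feature to its quantization group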
+    weight = []
+    infeatures = intweight.shape[1]
+    for idx in range(infeatures):
+        weight.append(
+            (
+                intweight[:, idx].float() * scales[:, g_idx[idx]]
+                - scale_zeros[:, g_idx[idx]]
+            )[:, None]
+        )
+    weight = torch.cat(weight, dim=1)
+
+    return weight