Skip to content

Commit

Permalink
Merge branch 'main' into clear-ml
Browse files Browse the repository at this point in the history
  • Loading branch information
horheynm authored Mar 26, 2024
2 parents 91246fe + ae121a5 commit 5c9af24
Show file tree
Hide file tree
Showing 12 changed files with 523 additions and 55 deletions.
80 changes: 80 additions & 0 deletions .github/workflows/build-container.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
name: Build Docker Container
on:
pull_request:
types: [opened, synchronize, reopened]
branches:
- main
- 'release/[0-9]+.[0-9]+'
push:
branches:
- 'main'
release:
types: [created, published]
schedule:
- cron: '0 2 * * *'

# TODO: docker containers created through a release cut vs PR to the release branch
# will be pushed to different locations (i.e one will be sparseml the other will be test-sparseml).
# These containers rely on the new internal pypi server being enabled. Once enabled,
# this workflow can be expanded to make this distinction.
env:
RELEASE: ${{ github.event_name =='release' || (startsWith(github.base_ref, 'release/') && github.event_name == 'pull_request')}}
DEV: ${{ github.base_ref == 'main' && github.event_name == 'pull_request'}}
NAME: ${{ github.event.number }}

permissions:
contents: read
packages: write

jobs:
build-container:
name: Build sparseml container
runs-on: ubuntu-20.04
steps:
- name: Checkout code
uses: actions/checkout@v3
with:
fetch-depth: 1
- name: Set up Docker Buildx
id: buildx
uses: docker/setup-buildx-action@v2
with:
buildkitd-flags: --debug
- name: Get current date
id: date
run: echo "::set-output name=date::$(date +'%Y%m%d')"
- name: Get the current version
if: ${{ env.RELEASE == 'true' }}
id: version
run: echo "::set-output name=version::$(echo ${{ github.base_ref }} | cut -c 9-15)"
- name: Login to Github Packages
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build Dev Docker Container
if: ${{ env.DEV == 'true' }}
uses: docker/build-push-action@v4
with:
context: ./docker/containers/docker_dev
build-args: |
BRANCH=${{github.head_ref}}
push: true
tags: ghcr.io/neuralmagic/sparseml-dev:${{ env.NAME }}
- name: Build Release Docker Container
if: ${{ env.RELEASE == 'true' }}
uses: docker/build-push-action@v4
with:
context: ./docker/containers/docker_release
build-args: |
VERSION=${{ steps.version.outputs.version }}
push: true
tags: ghcr.io/neuralmagic/test-sparseml:latest, ghcr.io/neuralmagic/test-sparseml:${{ steps.version.outputs.version }}
- name: Build Nightly Docker Container
if: ${{ env.DEV == 'false' && env.RELEASE == 'false'}}
uses: docker/build-push-action@v4
with:
context: ./docker/containers/docker_nightly
push: true
tags: ghcr.io/neuralmagic/test-sparseml-nightly:latest, ghcr.io/neuralmagic/test-sparseml-nightly:${{ steps.date.outputs.date }}
25 changes: 25 additions & 0 deletions docker/containers/docker_dev/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
ARG SOURCE=ghcr.io/neuralmagic/cuda-python3.10

ARG TORCH_VERSION=2.1.2
ARG TORCHVISION_VERSION=0.16.2
ARG CUDA=121
ARG BRANCH

FROM $SOURCE

ARG BRANCH

RUN python3.10 -m pip install --upgrade pip \
&& python3.10 -m pip install --upgrade setuptools

ARG CUDA
ARG TORCH_VERSION
ARG TORCHVISION_VERSION

RUN python3.10 -m pip install torch==${TORCH_VERSION}+cu${CUDA} torchvision==${TORCHVISION_VERSION}+cu${CUDA} -f https://download.pytorch.org/whl/torch_stable.html \
&& git clone https://github.com/neuralmagic/sparseml.git --depth 1 --single-branch -b ${BRANCH} \
&& python3.10 -m pip install -e "./sparseml[dev]"

HEALTHCHECK CMD python3.10 -c 'import sparseml'
RUN python3.10 -m pip list | grep sparseml
CMD bash
21 changes: 21 additions & 0 deletions docker/containers/docker_nightly/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
ARG SOURCE=ghcr.io/neuralmagic/cuda-python3.10

ARG TORCH_VERSION=2.1.2
ARG TORCHVISION_VERSION=0.16.2
ARG CUDA=121

FROM $SOURCE

RUN python3.10 -m pip install --upgrade pip \
&& python3.10 -m pip install --upgrade setuptools

ARG CUDA
ARG TORCH_VERSION
ARG TORCHVISION_VERSION

RUN python3.10 -m pip install torch==${TORCH_VERSION}+cu${CUDA} torchvision==${TORCHVISION_VERSION}+cu${CUDA} -f https://download.pytorch.org/whl/torch_stable.html \
&& python3.10 -m pip install --no-cache-dir "sparseml-nightly[onnxruntime,torchvision,transformers,yolov5,ultralytics]"

HEALTHCHECK CMD python3.10 -c 'import sparseml'
RUN python3.10 -m pip list | grep sparseml
CMD bash
24 changes: 24 additions & 0 deletions docker/containers/docker_release/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
ARG SOURCE=ghcr.io/neuralmagic/cuda-python3.10

ARG TORCH_VERSION=2.1.2
ARG TORCHVISION_VERSION=0.16.2
ARG CUDA=121
ARG VERSION

FROM $SOURCE

ARG VERSION

ARG CUDA
ARG TORCH_VERSION
ARG TORCHVISION_VERSION

RUN python3.10 -m pip install --upgrade pip \
&& python3.10 -m pip install --upgrade setuptools

RUN python3.10 -m pip install torch==${TORCH_VERSION}+cu${CUDA} torchvision==${TORCHVISION_VERSION}+cu${CUDA} -f https://download.pytorch.org/whl/torch_stable.html \
&& python3.10 -m pip install --no-cache-dir "sparseml[onnxruntime,torchvision,transformers,yolov5,ultralytics]==$VERSION"

HEALTHCHECK CMD python3.10 -c 'import sparseml'
RUN python3.10 -m pip list | grep sparseml
CMD bash
55 changes: 35 additions & 20 deletions src/sparseml/transformers/compression/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,32 +35,47 @@ needed for decompression in the compressed state_dict:
```python
from sparseml.transformers import SparseAutoModelForCausalLM
from sparseml.transformers.compression import BitmaskConfig, BitmaskCompressor
from safetensors import safe_open
import os
from sparseml.utils.pytorch.utils import measure_cuda_memory
from tqdm import tqdm
import torch

MODEL_PATH = "zoo:llama2-7b-gsm8k_llama2_pretrain-pruned50.oneshot"
OUTPUT_PATH = "./test_compress_output"

model = SparseAutoModelForCausalLM.from_pretrained(MODEL_PATH)
torch.cuda.set_device(0)
with measure_cuda_memory() as m:
model = SparseAutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="cuda:0")
print(f"Load dense model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB")

sparsity_config = BitmaskConfig()
compressor = BitmaskCompressor(config=sparsity_config)

model_state_dict = model.state_dict()
sparse_state_dict = compressor.compress(model_state_dict)


model.save_pretrained(OUTPUT_PATH, safe_serialization=True, state_dict=sparse_state_dict)

safetensors_path = os.path.join(OUTPUT_PATH, "model-00001-of-00002.safetensors")
with safe_open(safetensors_path, framework="pt", device=0) as f:
test_name = "model.layers.4.self_attn.k_proj.weight"
bitmask = f.get_tensor(test_name + ".bitmask")
shape = f.get_tensor(test_name + ".shape")
values = f.get_tensor(test_name + ".compressed")
row_offsets = f.get_tensor(test_name + ".row_offsets")
print(f"bitmask: {bitmask}")
print(f"shape: {shape}")
print(f"values: {values}")
print(f"row offsets: {row_offsets}")
# compresses the model using Bitmask compression
with measure_cuda_memory() as m:
model_state_dict = model.state_dict()
sparse_state_dict = compressor.compress(model_state_dict)

# save the compressed model
model.save_pretrained(
OUTPUT_PATH,
safe_serialization=True,
state_dict=sparse_state_dict
)

print(f"Save compressed model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB")

# use the dense state dict to reload the model
torch.cuda.set_device(1)
with measure_cuda_memory() as m:
model_again = SparseAutoModelForCausalLM.from_pretrained(
OUTPUT_PATH,
device_map="cuda:1"
)

#returns iterator
dense_state_dict = compressor.decompress(OUTPUT_PATH)
for name, data in tqdm(dense_state_dict, desc="Decompressing model"):
BitmaskCompressor.replace_layer(name, data, model_again)

print(f"Load compressed model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB")
```
25 changes: 21 additions & 4 deletions src/sparseml/transformers/compression/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict
import operator
from typing import Dict, Generator

from torch import Tensor
from torch.nn import Module, Parameter

from sparseml.transformers.compression.config import CompressionConfig
from sparseml.utils.pytorch.module import set_layer
from sparsezoo.utils.registry import RegistryMixin


Expand All @@ -42,11 +45,25 @@ def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""
raise NotImplementedError()

def decompress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
def decompress(self, model_path: str) -> Generator:
"""
Uncompresses a compressed state dict back to dense
Reads a compressed state dict located at model_path and returns a
generator for sequentially decompressing back to a dense state dict
:param model_state: state dict of uncompressed model
:param model_path: path to compressed safetensors model
:return: compressed state dict
"""
raise NotImplementedError()

@staticmethod
def replace_layer(param_name: str, data: Tensor, model: Module):
"""
Overwrites a parameterized layer with a new tensor, maintaining the device of
the original parameter
:param param_name: name of parameterized layer to replace
:param data: tensor to insert into model
:param model: pytorch model to insert data into
"""
model_device = operator.attrgetter(param_name)(model).device
set_layer(param_name, Parameter(data.to(model_device)), model)
Loading

0 comments on commit 5c9af24

Please sign in to comment.