forked from ml-explore/mlx-examples
* probably approximately correct CLIPTextEncoder
* implemented CLIPEncoderLayer as built-in nn.TransformerEncoderLayer
* replaced embedding layer with simple matrix
* implemented ViT
* added ViT tests
* fixed tests
* added pooler_output for text
* implemented complete CLIPModel
* implemented init
* implemented convert.py and from_pretrained
* fixed some minor bugs and added the README.md
* removed tokenizer unused comments
* removed unused deps
* updated ACKNOWLEDGEMENTS.md
* Feat: Image Processor for CLIP (#1) @nkasmanoff:
  * clip image processor
  * added example usage
  * refactored image preprocessing
  * deleted unused image_config.py
  * removed preprocessing port
  * added dependency to mlx-data
  * fixed attribution and moved photos to assets
  * implemented a simple port of CLIPImageProcessor
  * review changes
  * PR review changes
  * renamed too verbose arg
  * updated README.md
* nits in readme / conversion
* simplify some stuff, remove unneeded inits
* remove more init stuff
* more simplify
* make test a unit test
* update main readme
* readme nits

---------

Co-authored-by: Noah Kasmanoff <[email protected]>
Co-authored-by: Awni Hannun <[email protected]>
1 parent ba3a935 · commit 9435821
Showing 14 changed files with 890 additions and 0 deletions.
`.gitignore` (presumably; one added line):

mlx_model/
`README.md`:
# CLIP

An example of OpenAI's CLIP in MLX. The CLIP (contrastive language-image
pre-training) model embeds images and text in the same space.[^1]

### Setup

Install the dependencies:

```shell
pip install -r requirements.txt
```

Next, download a CLIP model from Hugging Face and convert it to MLX. The
default model is
[openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32).

```shell
python convert.py
```

The script will by default download the model and configuration files to the
directory ``mlx_model/``.
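
`convert.py` (shown later in this diff) also exposes `--hf-repo` and `--mlx-path` flags, so converting a different checkpoint looks something like the following (`mlx_model_large` is just an example output directory):

```shell
python convert.py \
    --hf-repo openai/clip-vit-large-patch14 \
    --mlx-path mlx_model_large
```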

### Run

You can use the CLIP model to embed images and text.

```python
from PIL import Image
import clip

model, tokenizer, img_processor = clip.load("mlx_model")
inputs = {
    "input_ids": tokenizer(["a photo of a cat", "a photo of a dog"]),
    "pixel_values": img_processor(
        [Image.open("assets/cat.jpeg"), Image.open("assets/dog.jpeg")]
    ),
}
output = model(**inputs)

# Get text and image embeddings:
text_embeds = output.text_embeds
image_embeds = output.image_embeds
```

Run the above example with `python clip.py`.

To embed only the text or only the images, pass only the ``input_ids`` or
``pixel_values``, respectively.
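
For example, a minimal sketch reusing the `model`, `tokenizer`, and `img_processor` loaded above:

```python
# Text only: pass input_ids and omit pixel_values.
text_only = model(input_ids=tokenizer(["a photo of a cat"]))
print(text_only.text_embeds.shape)

# Images only: pass pixel_values and omit input_ids.
image_only = model(pixel_values=img_processor([Image.open("assets/cat.jpeg")]))
print(image_only.image_embeds.shape)
```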

This example re-implements minimal image preprocessing and tokenization to reduce
dependencies. For additional preprocessing functionality, you can use
``transformers``. The file `hf_preproc.py` has an example.

MLX CLIP has been tested and works with the following Hugging Face repos:

- [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32)
- [openai/clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)

You can run the tests with:

```shell
python test.py
```

To test new models, update the `MLX_PATH` and `HF_PATH` in `test.py`.
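
For instance (illustrative values only; the actual assignments live in `test.py`, which is not shown in this commit view):

```python
# In test.py (hypothetical example values):
MLX_PATH = "mlx_model"
HF_PATH = "openai/clip-vit-large-patch14"
```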

### Attribution

- `assets/cat.jpeg` is a "Cat" by London's, licensed under CC BY-SA 2.0.
- `assets/dog.jpeg` is a "Happy Dog" by tedmurphy, licensed under CC BY 2.0.

[^1]: Refer to the original paper [Learning Transferable Visual Models From
Natural Language Supervision](https://arxiv.org/abs/2103.00020) or [blog
post](https://openai.com/research/clip).
`clip.py`:
from typing import Tuple

from image_processor import CLIPImageProcessor
from model import CLIPModel
from tokenizer import CLIPTokenizer


def load(model_dir: str) -> Tuple[CLIPModel, CLIPTokenizer, CLIPImageProcessor]:
    model = CLIPModel.from_pretrained(model_dir)
    tokenizer = CLIPTokenizer.from_pretrained(model_dir)
    img_processor = CLIPImageProcessor.from_pretrained(model_dir)
    return model, tokenizer, img_processor


if __name__ == "__main__":
    from PIL import Image

    model, tokenizer, img_processor = load("mlx_model")
    inputs = {
        "input_ids": tokenizer(["a photo of a cat", "a photo of a dog"]),
        "pixel_values": img_processor(
            [Image.open("assets/cat.jpeg"), Image.open("assets/dog.jpeg")]
        ),
    }
    output = model(**inputs)

    # Get text and image embeddings:
    text_embeds = output.text_embeds
    image_embeds = output.image_embeds
    print("Text embeddings shape:", text_embeds.shape)
    print("Image embeddings shape:", image_embeds.shape)
`convert.py`:
# Copyright © 2023-2024 Apple Inc.

import argparse
import shutil
from pathlib import Path
from typing import Tuple

import mlx.core as mx
import torch
from huggingface_hub import snapshot_download


def get_model_path(path_or_hf_repo: str) -> Path:
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        model_path = Path(
            snapshot_download(
                repo_id=path_or_hf_repo,
                allow_patterns=[
                    "*.bin",
                    "*.json",
                    "*.txt",
                ],
            )
        )
    return model_path


def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
    # bfloat16 is not numpy convertible. Upcast to float32 to avoid precision loss
    a = a.to(torch.float32) if dtype == "bfloat16" else a.to(getattr(torch, dtype))
    return mx.array(a.numpy(), getattr(mx, dtype))

def map_weights(key: str, value: torch.Tensor) -> Tuple[str, mx.array]:
    key = key.replace("embeddings.", "")
    key = key.replace("encoder.", "")
    key = key.replace("position_embedding.weight", "position_embedding")

    # Map attention layers
    if "self_attn." in key:
        key = key.replace("self_attn.", "attention.")
    if "q_proj." in key:
        key = key.replace("q_proj.", "query_proj.")
    if "k_proj." in key:
        key = key.replace("k_proj.", "key_proj.")
    if "v_proj." in key:
        key = key.replace("v_proj.", "value_proj.")
    if "layer_norm1." in key:
        key = key.replace("layer_norm1.", "ln1.")
    if "layer_norm2." in key:
        key = key.replace("layer_norm2.", "ln2.")
    # Map ffn layers
    if "mlp.fc1" in key:
        key = key.replace("mlp.fc1", "linear1")
    if "mlp.fc2" in key:
        key = key.replace("mlp.fc2", "linear2")
    # Fix layernorm typo
    if "pre_layrnorm" in key:
        # Fix typo in weights :)
        key = key.replace("pre_layrnorm", "pre_layernorm")
    if "patch_embedding.weight" in key:
        # Initially, value: [out_channels, in_channels, kH, kW].
        # We want [out_channels, kH, kW, in_channels]
        value = value.permute(0, 2, 3, 1)
    return (key, torch_to_mx(value, dtype=str(value.dtype).replace("torch.", "")))


def should_keep_weight(key: str):
    return not ("position_ids" in key)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Download and Convert (OpenAI) CLIP weights to MLX"
    )
    parser.add_argument(
        "--hf-repo",
        type=str,
        default="openai/clip-vit-base-patch32",
        help="Hugging Face repository name.",
    )
    parser.add_argument(
        "--mlx-path",
        type=str,
        default="mlx_model",
        help="Path to save the MLX model.",
    )

    args = parser.parse_args()

    torch_path = get_model_path(args.hf_repo)
    mlx_path = Path(args.mlx_path)
    mlx_path.mkdir(parents=True, exist_ok=True)

    print("[INFO] Loading")
    torch_weights = torch.load(torch_path / "pytorch_model.bin")
    print("[INFO] Converting")
    mlx_weights = dict(map_weights(k, v) for (k, v) in torch_weights.items())
    mlx_weights = {k: v for (k, v) in mlx_weights.items() if should_keep_weight(k)}
    print("[INFO] Saving")
    mx.savez(str(mlx_path / "weights.npz"), **mlx_weights)
    for fn in ["config.json", "merges.txt", "vocab.json", "preprocessor_config.json"]:
        shutil.copyfile(
            str(torch_path / f"{fn}"),
            str(mlx_path / f"{fn}"),
        )
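
As a sanity check, a small sketch of what `map_weights` does to one attention weight; the key below is an illustrative example of a standard Hugging Face CLIP key, and the dummy tensor shape is arbitrary:

```python
import torch

from convert import map_weights

# "encoder." is dropped, "self_attn" -> "attention", "q_proj" -> "query_proj".
key, value = map_weights(
    "text_model.encoder.layers.0.self_attn.q_proj.weight",
    torch.zeros(512, 512),
)
print(key)  # text_model.layers.0.attention.query_proj.weight
```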
`hf_preproc.py`:
import mlx.core as mx
import transformers
from PIL import Image

import clip

hf_model = "openai/clip-vit-base-patch32"
mlx_model = "mlx_model"

model, *_ = clip.load(mlx_model)
processor = transformers.CLIPProcessor.from_pretrained(hf_model)

inputs = processor(
    text=["a photo of a cat", "a photo of a dog"],
    images=[Image.open("assets/cat.jpeg"), Image.open("assets/dog.jpeg")],
    return_tensors="np",
)

out = model(
    input_ids=mx.array(inputs.input_ids),
    pixel_values=mx.array(inputs.pixel_values).transpose((0, 2, 3, 1)),
    return_loss=True,
)

print("text embeddings:")
print(out.text_embeds)
print("image embeddings:")
print(out.image_embeds)
print(f"CLIP loss: {out.loss.item():.3f}")
`image_processor.py`:
# Copyright © 2023-2024 Apple Inc.

import json
from pathlib import Path
from typing import List, Tuple

import mlx.core as mx
import numpy as np
from PIL.Image import Image


class CLIPImageProcessor:
    """
    A simple port of
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/image_processing_clip.py.
    """

    def __init__(
        self,
        crop_size: int = 224,
        do_center_crop: bool = True,
        do_normalize: bool = True,
        do_resize: bool = True,
        image_mean: List[float] = [0.48145466, 0.4578275, 0.40821073],
        image_std: List[float] = [0.26862954, 0.26130258, 0.27577711],
        size: int = 224,
        **kwargs
    ) -> None:
        self.crop_size = crop_size
        self.do_center_crop = do_center_crop
        self.do_normalize = do_normalize
        self.do_resize = do_resize
        self.image_mean = mx.array(image_mean)
        self.image_std = mx.array(image_std)
        self.size = size

    def __call__(self, images: List[Image]) -> mx.array:
        return mx.concatenate(
            [self._preprocess(image)[None] for image in images], axis=0
        )

    def _preprocess(self, image: Image) -> mx.array:
        if self.do_resize:
            image = resize(image, self.size)
        if self.do_center_crop:
            image = center_crop(image, (self.crop_size, self.crop_size))
        image = mx.array(np.array(image))
        image = rescale(image)
        if self.do_normalize:
            image = normalize(image, self.image_mean, self.image_std)
        return image

    @staticmethod
    def from_pretrained(path: str):
        path = Path(path)
        with open(path / "preprocessor_config.json", encoding="utf-8") as f:
            config = json.load(f)
        return CLIPImageProcessor(**config)

def resize(image: Image, short_size: int) -> Image:
    """
    Resize so that the shorter side matches `short_size`, preserving the aspect ratio.
    """
    width, height = image.size
    short = min(width, height)
    long = max(width, height)
    if short == short_size:
        return image
    new_short = short_size
    new_long = int(short_size * long / short)
    new_size = (new_short, new_long) if width <= height else (new_long, new_short)
    return image.resize(new_size)


def center_crop(image: Image, size: Tuple[int, int]) -> Image:
    if size[0] % 2 != 0 or size[1] % 2 != 0:
        raise ValueError("Only even crop sizes supported.")
    original_width, original_height = image.size
    crop_height, crop_width = size
    top = (original_height - crop_height) // 2
    bottom = top + crop_height
    left = (original_width - crop_width) // 2
    right = left + crop_width
    return image.crop((left, top, right, bottom))


def rescale(image: mx.array) -> mx.array:
    return image.astype(mx.float32) * (1 / 255.0)


def normalize(image: mx.array, mean: mx.array, std: mx.array) -> mx.array:
    return (image - mean) / std
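
A quick standalone usage sketch for the processor above, assuming the converted model directory and bundled assets from the README; the expected output shape assumes the default 224-pixel crop:

```python
from PIL import Image

from image_processor import CLIPImageProcessor

# Load the preprocessing config that convert.py copies into mlx_model/.
proc = CLIPImageProcessor.from_pretrained("mlx_model")
batch = proc([Image.open("assets/cat.jpeg"), Image.open("assets/dog.jpeg")])
print(batch.shape)  # (2, 224, 224, 3) with the default crop_size/size of 224
```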