docs: add code example and introduction about image_classification. (#59)

* chore: add tools in docker image

* docs: add code example about image_classification.

* docs: update README.md
moyanxinxu authored Dec 7, 2024
1 parent 85458e3 commit e98f649
Showing 10 changed files with 383 additions and 15 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -5,4 +5,6 @@ main.ipynb
main.py
__pycache__
.DS_Store
data
data
model
dataset
4 changes: 3 additions & 1 deletion .vscode/settings.json
@@ -2,5 +2,7 @@
"markdownlint.config": {
"MD010": false,
"MD033": false
}
},
"debugpy.debugJustMyCode": false,
// "debugpy.debugJustMyCode": true,
}
8 changes: 2 additions & 6 deletions README.md
@@ -43,12 +43,6 @@
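| Name | Role | Bio |
| :----- | :--------------------- | :----------------- |
| 田健翔 | Project lead | Content creator |
| 于小敏 | Project advisor | Official DataWhale member |
| 卢鑫斌 | Chapter 1 (Datasets) contributor | Content creator |
| 胥佳程 | Chapter 3 (PEFT) contributor | Content creator |
| 秦子涵 | Chapter 5 (Diffusers) contributor | Content creator |
| 陈凯歌 | Chapter 7 (Gradio) contributor | Content creator |
| 刘硕 | Chapter 7 (Gradio) contributor | Content creator |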

- PEFT
- LoRa:@[鑫民](https://github.com/fancyboi999)
@@ -57,6 +51,8 @@
- Prefix-Tuning:@[鑫民](https://github.com/fancyboi999)
- prompt-Tuning:@[鑫民](https://github.com/fancyboi999)
- P-Tuning:@[鑫民](https://github.com/fancyboi999)
- Code examples
- Image classification: @[陈相斌](https://github.com/chenxinxi)

Project caretaker (o^^o): 高增玉

1 change: 1 addition & 0 deletions docker-compose/Dockerfile
@@ -13,6 +13,7 @@ RUN apt-get update && apt-get install -y \
unzip \
inetutils-ping \
tmux \
watch \
&& apt-get clean

COPY --from=miniconda-stage /opt/conda /opt/conda
1 change: 1 addition & 0 deletions docs/chapter6/code_index.md
@@ -14,3 +14,4 @@ title: Index
| Text translation | [Chinese-English text translation](./translation/translation.md) |
| Diffusion denoising | [Simple denoising with ddpm-unet](./ddpm-unet-mnist/ddpm-unet-mnist.md) |
| Text classification | [Intent recognition for fund annual report Q&A](./financial_report/financial_report.md) |
| Image classification | [Dish image classification](./image_classification/image_classification.md) |
221 changes: 221 additions & 0 deletions docs/chapter6/image_classification/image_classification.md
@@ -0,0 +1,221 @@
---
comments: true
title: Dish Image Classification
---

![image_classification](./imgs/image_classification.png)

## Introduction

## Code

```python
model_checkpoint = "google/vit-base-patch16-224-in21k"
```

### Import libraries

```python
import evaluate
import numpy as np
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    Trainer,
    TrainingArguments,
)
```

### Load the dataset

```python
dataset = load_dataset("food101", split="train[:5000]")

labels = dataset.features["label"].names

label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = i
    id2label[i] = label
```
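
As a small illustration of the two mappings built by the loop above, the sketch below uses a toy list of class names in place of `dataset.features["label"].names` (the three names are stand-ins chosen for the example) and checks that the mappings are inverses of each other.

```python
# Toy stand-in for dataset.features["label"].names (illustrative subset only).
labels = ["apple_pie", "baby_back_ribs", "baklava"]

# Same mappings as the loop above, written as dict comprehensions.
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for i, label in enumerate(labels)}

# Round trip: id -> name -> id should give back the original id.
round_trip = [label2id[id2label[i]] for i in range(len(labels))]
print(label2id)
print(round_trip)
```

Passing both dictionaries to the model lets its configuration report class names instead of bare indices.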

The dataset page for `food101` is embedded below.

<iframe
src="https://huggingface.co/datasets/ethz/food101/embed/viewer/default/train"
frameborder="0"
width="100%"
height="560px"
></iframe>

### Load the model

```python
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
    ignore_mismatched_sizes=True,
)

config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["query", "value"],
    lora_dropout=0.1,
    bias="none",
    modules_to_save=["classifier"],
)
lora_model = get_peft_model(model, config)
```

After wrapping the model for parameter-efficient fine-tuning, printing the trainable parameters gives:

```python title="lora_model.print_trainable_parameters()"
trainable params: 667,493 || all params: 86,543,818 || trainable%: 0.7713
```
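
The trainable count can be reproduced by hand. The sketch below is only a back-of-the-envelope check: it assumes the usual ViT-Base shape (hidden size 768, 12 encoder layers), which the text above does not state, together with the 101 Food-101 classes, and it takes `r` and the target modules from the `LoraConfig` shown earlier.

```python
# Assumed architecture constants for ViT-Base (not stated in the source).
hidden_size = 768
num_layers = 12
num_classes = 101  # Food-101 has 101 categories

# Values taken from the LoraConfig above.
r = 16
targets_per_layer = 2  # "query" and "value"

# Each adapted Linear(hidden, hidden) gains A (r x hidden) and B (hidden x r).
lora_per_module = r * hidden_size + hidden_size * r
lora_params = lora_per_module * targets_per_layer * num_layers

# modules_to_save=["classifier"] keeps a trainable copy of the head: weights + biases.
classifier_params = hidden_size * num_classes + num_classes

trainable = lora_params + classifier_params
all_params = 86_543_818  # total reported in the output above
pct = 100 * trainable / all_params
print(f"trainable params: {trainable:,} || trainable%: {pct:.4f}")
```

Under these assumptions the LoRA matrices contribute 589,824 parameters and the classification head 77,669, which together match the reported 667,493.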

### Load the image processor

```python
image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)
```

### Define the data transforms

```python
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
train_transforms = Compose(
    [
        RandomResizedCrop(image_processor.size["height"]),
        RandomHorizontalFlip(),
        ToTensor(),
        normalize,
    ]
)

val_transforms = Compose(
    [
        Resize(image_processor.size["height"]),
        CenterCrop(image_processor.size["height"]),
        ToTensor(),
        normalize,
    ]
)

def preprocess_train(example_batch):
    """Apply train_transforms across a batch."""
    example_batch["pixel_values"] = [
        train_transforms(image.convert("RGB")) for image in example_batch["image"]
    ]
    return example_batch


def preprocess_val(example_batch):
    """Apply val_transforms across a batch."""
    example_batch["pixel_values"] = [
        val_transforms(image.convert("RGB")) for image in example_batch["image"]
    ]
    return example_batch
```
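
To make the `Normalize` step concrete, here is a plain-Python sketch of the per-channel arithmetic, `(x - mean) / std`, applied after `ToTensor()` has scaled pixel values to `[0, 1]`. The mean and std of 0.5 used here are an assumption for illustration; the real values are whatever `image_processor.image_mean` and `image_processor.image_std` return for the checkpoint. The helper `normalize_pixel` is hypothetical and not part of torchvision.

```python
# Assumed per-channel statistics; the real values come from
# image_processor.image_mean / image_processor.image_std.
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]


def normalize_pixel(rgb, mean, std):
    """Apply (x - mean) / std to one RGB pixel with values in [0, 1]."""
    return [(x - m) / s for x, m, s in zip(rgb, mean, std)]


black = normalize_pixel([0.0, 0.0, 0.0], mean, std)
white = normalize_pixel([1.0, 1.0, 1.0], mean, std)
print(black, white)
```

With these assumed statistics the `[0, 1]` range is mapped to `[-1, 1]`.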

### Preprocess the data

```python
splits = dataset.train_test_split(test_size=0.1)
train_ds = splits["train"]
val_ds = splits["test"]

train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)
```

### Define the evaluation metric

```python
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Computes accuracy on a batch of predictions"""
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)
```
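
For intuition, the same computation can be sketched without `evaluate` or NumPy: take the highest-scoring class per row, then count how often it equals the label. The helpers `argmax` and `accuracy` and the logits below are made up for this illustration and are not part of the tutorial code.

```python
def argmax(row):
    """Index of the largest value in a list of scores."""
    return max(range(len(row)), key=row.__getitem__)


def accuracy(logits, label_ids):
    """Fraction of rows whose highest-scoring class matches the label."""
    predictions = [argmax(row) for row in logits]
    correct = sum(p == y for p, y in zip(predictions, label_ids))
    return {"accuracy": correct / len(label_ids)}


# Made-up logits for 4 images and 3 classes (illustrative only).
logits = [
    [2.0, 0.1, -1.0],
    [0.3, 1.5, 0.2],
    [0.0, 0.2, 0.9],
    [1.1, 0.4, 0.3],
]
label_ids = [0, 1, 1, 0]
result = accuracy(logits, label_ids)
print(result)
```

Three of the four predictions match their labels here, so the sketch reports an accuracy of 0.75.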

### Define the data collator

```python
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}
```

### Define the training arguments

```python
args = TrainingArguments(
    "vit-finetuned-lora-food101",
    remove_unused_columns=False,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    learning_rate=5e-3,
    per_device_train_batch_size=128,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=128,
    fp16=True,
    num_train_epochs=5,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    label_names=["labels"],
    use_cpu=False,
)
```
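
The batch-related arguments interact, so a quick calculation may help. The sketch below assumes a single device, since the source does not say what hardware was used; `batch_plan` and its `num_devices` parameter are hypothetical names introduced only for this example.

```python
import math


def batch_plan(num_examples, per_device_batch, grad_accum, num_devices=1):
    """Return (effective batch size, batches per epoch on each device)."""
    effective_batch = per_device_batch * grad_accum * num_devices
    batches_per_epoch = math.ceil(num_examples / (per_device_batch * num_devices))
    return effective_batch, batches_per_epoch


# 5000 images minus the 10% held out by train_test_split(test_size=0.1).
train_examples = 4500
effective_batch, batches_per_epoch = batch_plan(train_examples, 128, 4)
print(train_examples, effective_batch, batches_per_epoch)
```

With one device, each optimizer update accumulates gradients over 512 images, and an epoch consists of 36 mini-batches. More devices would raise the effective batch size and reduce the batches per epoch accordingly.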

### Define the trainer

```python
trainer = Trainer(
    lora_model,
    args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    # Newer transformers releases name this argument `processing_class`.
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)
```

### Train and evaluate

```python
trainer.train()
trainer.evaluate(val_ds)
```

The intermediate results logged during training are shown below.

| Epoch | Eval loss | Eval accuracy |
| ----- | --------- | ------------- |
| 0.8 | 4.0372 | 0.80 |
| 1.6 | 3.5086 | 0.876 |
| 2.4 | 3.0289 | 0.896 |
| 4.0 | 2.4545 | 0.908 |

## References

To be added.