Add files via upload #136

Open · wants to merge 6 commits into main
101 changes: 101 additions & 0 deletions README_zh.md
@@ -0,0 +1,101 @@
## 📜 Requirements

This repo consists of DialogGen (a prompt enhancement model) and Hunyuan-DiT (a text-to-image model).

The following table shows the requirements for running the models (batch size = 1):

|          Model          | 4-bit Quantized DialogGen | Peak GPU Memory |  Supported GPU  |
|:-----------------------:|:-------------------------:|:---------------:|:---------------:|
| DialogGen + Hunyuan-DiT |             ✗             |       32G       |      A100       |
| DialogGen + Hunyuan-DiT |             ✓             |       22G       |      A100       |
|       Hunyuan-DiT       |             -             |       11G       |      A100       |
|       Hunyuan-DiT       |             -             |       14G       | RTX3090/RTX4090 |

* An NVIDIA GPU with CUDA support is required:
  * We have tested the project on V100 and A100 GPUs.
  * **Minimum**: The minimum GPU memory required is 11GB; see the sketch after this list for a quick check.
  * **Recommended**: We recommend a GPU with 32GB of memory for better generation quality.
* Tested operating system: Linux
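A quick way to check that a card clears the 11GB floor before launching the demo; a minimal sketch using PyTorch, which the project installs anyway:

```python
import torch

# Print the total memory of every visible CUDA device; the demo needs at
# least 11GB (32GB recommended) on a single GPU.
assert torch.cuda.is_available(), "A CUDA-capable NVIDIA GPU is required."
for i in range(torch.cuda.device_count()):
    props = torch.cuda.get_device_properties(i)
    print(f"GPU {i}: {props.name}, {props.total_memory / 1024**3:.1f} GB")
```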

## 🛠️ Dependencies and Installation

First, clone the repo:
```bash
git clone https://github.com/tencent/HunyuanDiT
cd HunyuanDiT
```


We provide an `environment.yml` file for creating a Conda environment.
Installation instructions for Conda are available [here](https://docs.anaconda.com/free/miniconda/index.html).


```bash
# 1. Prepare the conda environment
conda env create -f environment.yml

# 2. Activate the environment
conda activate HunyuanDiT

# 3. Install pip dependencies
python -m pip install -r requirements.txt

# 4. (Optional) Install flash attention v2 for acceleration (requires CUDA 11.6 or above)
python -m pip install git+https://github.com/Dao-AILab/[email protected]
```

We recommend CUDA versions 11.7 and 12.0+.
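After installation, you can sanity-check the environment from Python. A minimal sketch (the `flash_attn` import only applies if you ran the optional step 4):

```python
import torch

print(torch.__version__, torch.version.cuda)  # CUDA build should be 11.7 or 12.0+
print(torch.cuda.is_available())              # must be True to run the models

try:
    import flash_attn  # optional accelerated attention from step 4
    print("flash attention:", flash_attn.__version__)
except ImportError:
    print("flash attention not installed (optional)")
```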



## 🧱 Download Pretrained Models

To download the models, first install huggingface-cli. (Detailed instructions are available [here](https://huggingface.co/docs/huggingface_hub/guides/cli).)

```shell
python -m pip install "huggingface_hub[cli]"
```

Then download the models with the following commands:

```shell
# Create a directory named 'ckpts', which stores the model weights and is a prerequisite for running the demo.
mkdir ckpts
# Download the models with the huggingface-cli tool.
# The download may take 10 minutes to 1 hour depending on your network conditions.
huggingface-cli download Tencent-Hunyuan/HunyuanDiT --local-dir ./ckpts
```
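If you prefer to drive the download from Python, `huggingface_hub` exposes the same operation as `snapshot_download`; a minimal sketch (the mirror variable matches the tip below and is only needed on slow networks):

```python
# Optional: route traffic through a mirror (see the tips below); it must be
# set before huggingface_hub is imported, e.g.:
#   import os; os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

from huggingface_hub import snapshot_download

# Equivalent to `huggingface-cli download Tencent-Hunyuan/HunyuanDiT --local-dir ./ckpts`;
# an interrupted download resumes when the call is rerun.
snapshot_download(repo_id="Tencent-Hunyuan/HunyuanDiT", local_dir="./ckpts")
```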


<details>
<summary>💡Tips for using huggingface-cli (network issues)</summary>

##### 1. Use an HF mirror

If the download speed is slow inside mainland China, you can use a mirror to speed it up, for example:
```shell
HF_ENDPOINT=https://hf-mirror.com huggingface-cli download Tencent-Hunyuan/HunyuanDiT --local-dir ./ckpts
```

##### 2. Resume an interrupted download

`huggingface-cli` supports resuming downloads. If the download is interrupted, simply rerun the download command to resume it.

Note: If an error such as `No such file or directory: 'ckpts/.huggingface/.gitignore.lock'` occurs during the download, you can ignore it and rerun `huggingface-cli download Tencent-Hunyuan/HunyuanDiT --local-dir ./ckpts`.

</details>

---

All models will be downloaded automatically. For more information about the models, visit the Hugging Face repo [here](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT).

|       Model        | #Params |                                               Download URL                                                |
|:------------------:|:-------:|:---------------------------------------------------------------------------------------------------------:|
| mT5 | 1.6B | [mT5](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/tree/main/t2i/mt5) |
| CLIP | 350M | [CLIP](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/tree/main/t2i/clip_text_encoder) |
| DialogGen | 7.0B | [DialogGen](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/tree/main/dialoggen) |
| sdxl-vae-fp16-fix | 83M | [sdxl-vae-fp16-fix](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/tree/main/t2i/sdxl-vae-fp16-fix) |
| Hunyuan-DiT | 1.5B | [Hunyuan-DiT](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/tree/main/t2i/model) |
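Judging from the subdirectory paths in the download links, the weights should land under `./ckpts` in a layout roughly like the following (a sketch inferred from the table above, not an exhaustive listing):

```
ckpts/
├── dialoggen/              # DialogGen prompt-enhancement model (7.0B)
└── t2i/
    ├── clip_text_encoder/  # CLIP text encoder (350M)
    ├── model/              # Hunyuan-DiT diffusion transformer (1.5B)
    ├── mt5/                # mT5 text encoder (1.6B)
    └── sdxl-vae-fp16-fix/  # VAE (83M)
```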




46 changes: 29 additions & 17 deletions comfyui-hydit/hydit/data_loader/arrow_load_stream.py
@@ -37,6 +37,7 @@ def __init__(self,
uncond_p_t5=0.0,
text_ctx_len_t5=256,
tokenizer_t5=None,
use_t5=False
):
self.args = args
self.resolution = resolution
@@ -59,6 +60,7 @@ def __init__(self,
self.tokenizer = tokenizer

# t5 params
self.use_t5 = use_t5
self.uncond_p_t5 = uncond_p_t5
self.text_ctx_len_t5 = text_ctx_len_t5
self.tokenizer_t5 = tokenizer_t5
@@ -230,26 +232,36 @@ def __getitem__(self, ind):
else:
description = self.get_text(ind)

# Get text for t5
if random.random() < self.uncond_p_t5:
description_t5 = ""
else:
description_t5 = self.get_text(ind)
if self.use_t5:
# Get text for t5
if random.random() < self.uncond_p_t5:
description_t5 = ""
else:
description_t5 = self.get_text(ind)

# Use encoder to embed tokens online
text, text_embedding, text_embedding_mask = self.get_text_info_with_encoder(description)

text_t5, text_embedding_t5, text_embedding_mask_t5 = self.get_text_info_with_encoder_t5(description_t5)

original_pil_image, kwargs = self.get_image_with_hwxy(ind)

# Use encoder to embed tokens online
text, text_embedding, text_embedding_mask = self.get_text_info_with_encoder(description)

text_t5, text_embedding_t5, text_embedding_mask_t5 = self.get_text_info_with_encoder_t5(description_t5)
return (
original_pil_image,
text_embedding.clone().detach(),
text_embedding_mask.clone().detach(),
text_embedding_t5.clone().detach(),
text_embedding_mask_t5.clone().detach(),
{k: torch.tensor(np.array(v)).clone().detach() for k, v in kwargs.items()},
)
# Build the sample tuple; the two T5 entries are present only when use_t5 is set.
items = [
    original_pil_image,
    text_embedding.clone().detach(),
    text_embedding_mask.clone().detach(),
]
if self.use_t5:
    items += [
        text_embedding_t5.clone().detach(),
        text_embedding_mask_t5.clone().detach(),
    ]
items.append({k: torch.tensor(np.array(v)).clone().detach() for k, v in kwargs.items()})
return tuple(items)

def __len__(self):
return len(self.index_manager)
1 change: 1 addition & 0 deletions comfyui-hydit/hydit/train.sh
@@ -38,4 +38,5 @@ sh $(dirname "$0")/run_g.sh \
--deepspeed \
--deepspeed-optimizer \
--use-zero-stage 2 \
--use_t5 True \
"$@"
143 changes: 95 additions & 48 deletions comfyui-hydit/hydit/train_deepspeed.py
@@ -108,8 +108,12 @@ def save_lora_weight(checkpoint_dir, client_state, tag=f"{train_steps:07d}.pt"):
return checkpoint_path

@torch.no_grad()
def prepare_model_inputs(args, batch, device, vae, text_encoder, text_encoder_t5, freqs_cis_img):
image, text_embedding, text_embedding_mask, text_embedding_t5, text_embedding_mask_t5, kwargs = batch
def prepare_model_inputs(args, batch, device, vae, text_encoder, text_encoder_t5, freqs_cis_img, use_t5):
# The batch layout depends on use_t5 (see TextImageArrowStream.__getitem__).
if use_t5:
image, text_embedding, text_embedding_mask, text_embedding_t5, text_embedding_mask_t5, kwargs = batch
else:
image, text_embedding, text_embedding_mask, kwargs = batch

# clip & mT5 text embedding
text_embedding = text_embedding.to(device)
@@ -118,15 +122,16 @@ def prepare_model_inputs(args, batch, device, vae, text_encoder, text_encoder_t5
text_embedding.to(device),
attention_mask=text_embedding_mask.to(device),
)[0]
text_embedding_t5 = text_embedding_t5.to(device).squeeze(1)
text_embedding_mask_t5 = text_embedding_mask_t5.to(device).squeeze(1)
with torch.no_grad():
output_t5 = text_encoder_t5(
input_ids=text_embedding_t5,
attention_mask=text_embedding_mask_t5 if T5_ENCODER['attention_mask'] else None,
output_hidden_states=True
)
encoder_hidden_states_t5 = output_t5['hidden_states'][T5_ENCODER['layer_index']].detach()
if use_t5:
text_embedding_t5 = text_embedding_t5.to(device).squeeze(1)
text_embedding_mask_t5 = text_embedding_mask_t5.to(device).squeeze(1)
with torch.no_grad():
output_t5 = text_encoder_t5(
input_ids=text_embedding_t5,
attention_mask=text_embedding_mask_t5 if T5_ENCODER['attention_mask'] else None,
output_hidden_states=True
)
encoder_hidden_states_t5 = output_t5['hidden_states'][T5_ENCODER['layer_index']].detach()

# additional condition
image_meta_size = kwargs['image_meta_size'].to(device)
@@ -147,16 +152,27 @@ def prepare_model_inputs(args, batch, device, vae, text_encoder, text_encoder_t5
cos_cis_img, sin_cis_img = freqs_cis_img[reso]

# Model conditions
model_kwargs = dict(
encoder_hidden_states=encoder_hidden_states,
text_embedding_mask=text_embedding_mask,
encoder_hidden_states_t5=encoder_hidden_states_t5,
text_embedding_mask_t5=text_embedding_mask_t5,
image_meta_size=image_meta_size,
style=style,
cos_cis_img=cos_cis_img,
sin_cis_img=sin_cis_img,
)
model_kwargs = dict(
    encoder_hidden_states=encoder_hidden_states,
    text_embedding_mask=text_embedding_mask,
    image_meta_size=image_meta_size,
    style=style,
    cos_cis_img=cos_cis_img,
    sin_cis_img=sin_cis_img,
)
if use_t5:
    # The mT5 stream contributes its own hidden states and mask.
    model_kwargs.update(
        encoder_hidden_states_t5=encoder_hidden_states_t5,
        text_embedding_mask_t5=text_embedding_mask_t5,
    )

return latents, model_kwargs

@@ -167,11 +183,12 @@ def main(args):
assert torch.cuda.is_available(), "Training currently requires at least one GPU."

dist.init_process_group("nccl")

world_size = dist.get_world_size()
batch_size = args.batch_size
grad_accu_steps = args.grad_accu_steps
global_batch_size = world_size * batch_size * grad_accu_steps

use_t5 = args.use_t5  # newly added: gates the mT5 text-encoder path
rank = dist.get_rank()
device = rank % torch.cuda.device_count()
seed = args.global_seed * world_size + rank
@@ -280,22 +297,25 @@ def main(args):
# Setup BERT tokenizer:
logger.info(f" Loading Bert tokenizer from {TOKENIZER}")
tokenizer = BertTokenizer.from_pretrained(TOKENIZER)
# Setup T5 text encoder
from hydit.modules.text_encoder import MT5Embedder
mt5_path = T5_ENCODER['MT5']
embedder_t5 = MT5Embedder(mt5_path, torch_dtype=T5_ENCODER['torch_dtype'], max_length=args.text_len_t5)
tokenizer_t5 = embedder_t5.tokenizer
text_encoder_t5 = embedder_t5.model
if use_t5:
# Setup T5 text encoder
from hydit.modules.text_encoder import MT5Embedder
mt5_path = T5_ENCODER['MT5']
embedder_t5 = MT5Embedder(mt5_path, torch_dtype=T5_ENCODER['torch_dtype'], max_length=args.text_len_t5)
tokenizer_t5 = embedder_t5.tokenizer
text_encoder_t5 = embedder_t5.model

if args.extra_fp16:
logger.info(f" Using fp16 for extra modules: vae, text_encoder")
vae = vae.half().to(device)
text_encoder = text_encoder.half().to(device)
text_encoder_t5 = text_encoder_t5.half().to(device)
if use_t5:
text_encoder_t5 = text_encoder_t5.half().to(device)
else:
vae = vae.to(device)
text_encoder = text_encoder.to(device)
text_encoder_t5 = text_encoder_t5.to(device)
if use_t5:
text_encoder_t5 = text_encoder_t5.to(device)

logger.info(f" Optimizer parameters: lr={args.lr}, weight_decay={args.weight_decay}")
logger.info(" Using deepspeed optimizer")
@@ -308,23 +328,40 @@ def main(args):
logger.info(f"Building Streaming Dataset.")
logger.info(f" Loading index file {args.index_file} (v2)")

dataset = TextImageArrowStream(args=args,
resolution=image_size[0],
random_flip=args.random_flip,
log_fn=logger.info,
index_file=args.index_file,
multireso=args.multireso,
batch_size=batch_size,
world_size=world_size,
random_shrink_size_cond=args.random_shrink_size_cond,
merge_src_cond=args.merge_src_cond,
uncond_p=args.uncond_p,
text_ctx_len=args.text_len,
tokenizer=tokenizer,
uncond_p_t5=args.uncond_p_t5,
text_ctx_len_t5=args.text_len_t5,
tokenizer_t5=tokenizer_t5,
)
# The T5-specific arguments are unused by the dataset when use_t5 is False,
# so a single construction covers both paths.
dataset = TextImageArrowStream(args=args,
                               resolution=image_size[0],
                               random_flip=args.random_flip,
                               log_fn=logger.info,
                               index_file=args.index_file,
                               multireso=args.multireso,
                               batch_size=batch_size,
                               world_size=world_size,
                               random_shrink_size_cond=args.random_shrink_size_cond,
                               merge_src_cond=args.merge_src_cond,
                               uncond_p=args.uncond_p,
                               text_ctx_len=args.text_len,
                               tokenizer=tokenizer,
                               uncond_p_t5=args.uncond_p_t5,
                               text_ctx_len_t5=args.text_len_t5,
                               tokenizer_t5=tokenizer_t5 if use_t5 else None,
                               use_t5=use_t5,
                               )
if args.multireso:
sampler = BlockDistributedSampler(dataset, num_replicas=world_size, rank=rank, seed=args.global_seed,
shuffle=False, drop_last=True, batch_size=batch_size)
@@ -445,7 +482,17 @@ def main(args):
for batch in loader:
step += 1

latents, model_kwargs = prepare_model_inputs(args, batch, device, vae, text_encoder, text_encoder_t5, freqs_cis_img)
# prepare_model_inputs now requires use_t5; text_encoder_t5 only exists when use_t5 is set.
latents, model_kwargs = prepare_model_inputs(args=args, batch=batch, device=device, vae=vae,
                                             text_encoder=text_encoder,
                                             text_encoder_t5=text_encoder_t5 if use_t5 else None,
                                             freqs_cis_img=freqs_cis_img, use_t5=use_t5)

# training model by deepspeed while use fp16
if args.use_fp16: