diff --git a/flux_train_network.py b/flux_train_network.py
index 75e975bae..297d02078 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -8,6 +8,7 @@
 from accelerate import Accelerator
 
 from library.device_utils import clean_memory_on_device, init_ipex
+from library.strategy_flux import move_vision_encoder_to_device
 
 init_ipex()
 
@@ -190,6 +191,8 @@ def get_text_encoder_outputs_caching_strategy(self, args):
                 args.skip_cache_check,
                 is_partial=self.train_clip_l or self.train_t5xxl,
                 apply_t5_attn_mask=args.apply_t5_attn_mask,
+                vision_cond_size=args.vision_cond_downsample,
+                redux_path=args.redux_model_path,
             )
         else:
             return None
@@ -250,6 +253,7 @@ def cache_text_encoder_outputs_if_needed(
             text_encoders[0].to("cpu")
             logger.info("move t5XXL back to cpu")
             text_encoders[1].to("cpu")
+            move_vision_encoder_to_device("cpu")
             clean_memory_on_device(accelerator.device)
 
         if not args.lowram:
@@ -372,6 +376,15 @@ def get_noise_pred_and_target(
         if not args.apply_t5_attn_mask:
             t5_attn_mask = None
 
+        if args.vision_cond_dropout < 1.0:
+            if random.uniform(0, 1) > args.vision_cond_dropout:
+                vision_encoder_conds = batch.get("vision_encoder_outputs_list", None)
+                vis_t5_out, vis_txt_ids, vis_attn_mask = vision_encoder_conds
+                t5_out = torch.cat([t5_out, vis_t5_out], dim=1)
+                txt_ids = torch.cat([txt_ids, vis_txt_ids], dim=1)
+                if args.apply_t5_attn_mask:
+                    t5_attn_mask = torch.cat([t5_attn_mask, vis_attn_mask], dim=1)
+
         def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
             # if not args.split_mode:
             # normal forward
diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py
index f7f06c5cf..a70f6932c 100644
--- a/library/flux_train_utils.py
+++ b/library/flux_train_utils.py
@@ -617,3 +617,22 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
         default=3.0,
         help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
     )
+
+    parser.add_argument(
+        "--redux_model_path",
+        type=str,
+        help="path to the Redux model (*.sft or *.safetensors), should be float16",
+    )
+    parser.add_argument(
+        "--vision_cond_downsample",
+        type=int,
+        default=0,
+        help="Downsample the Redux image tokens to the specified grid size (at most 27). 0 disables vision conditioning (default).",
+    )
+
+    parser.add_argument(
+        "--vision_cond_dropout",
+        type=float,
+        default=1.0,
+        help="Probability of dropping the Redux conditioning at each training step. 1.0 (default) never applies the conditioning, 0.0 always applies it.",
+    )
diff --git a/library/strategy_flux.py b/library/strategy_flux.py
index 5e65927f8..627a8d420 100644
--- a/library/strategy_flux.py
+++ b/library/strategy_flux.py
@@ -1,9 +1,12 @@
 import os
-import glob
+from math import sqrt
 from typing import Any, List, Optional, Tuple, Union
+
+import safetensors.torch
 import torch
 import numpy as np
-from transformers import CLIPTokenizer, T5TokenizerFast
+import PIL.Image
+from transformers import CLIPTokenizer, T5TokenizerFast, SiglipVisionModel, AutoProcessor
 
 from library import flux_utils, train_util
 from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
@@ -20,6 +23,38 @@
 T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
 
 
+# FIXME: this is a very hacky way of handling the encoder model
+siglip_model = None
+siglip_processor = None
+redux_encoder = None
+
+def move_vision_encoder_to_device(device):
+    if siglip_model is not None:
+        siglip_model.to(device)
+    if redux_encoder is not None:
+        redux_encoder.to(device)
+
+
+class ReduxImageEncoder(torch.nn.Module):
+    def __init__(
+        self,
+        redux_dim: int = 1152,
+        txt_in_features: int = 4096,
+        device=None,
+        dtype=None,
+    ) -> None:
+        super().__init__()
+        self.redux_dim = redux_dim
+        self.device = device
+        self.dtype = dtype
+        self.redux_up = torch.nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
+        self.redux_down = torch.nn.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
+
+    def forward(self, sigclip_embeds) -> torch.Tensor:
+        projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
+        return projected_x
+
+
 class FluxTokenizeStrategy(TokenizeStrategy):
     def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
         self.t5xxl_max_length = t5xxl_max_length
@@ -95,10 +130,13 @@ def __init__(
         skip_disk_cache_validity_check: bool,
         is_partial: bool = False,
         apply_t5_attn_mask: bool = False,
+        vision_cond_size: int = 0,
+        redux_path: str = None,
     ) -> None:
         super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
         self.apply_t5_attn_mask = apply_t5_attn_mask
-
+        self.vision_cond_size = vision_cond_size
+        self.redux_path = redux_path
         self.warn_fp8_weights = False
 
     def get_outputs_npz_path(self, image_abs_path: str) -> str:
@@ -142,6 +180,49 @@ def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
         # apply_t5_attn_mask should be same as self.apply_t5_attn_mask
         return [l_pooled, t5_out, txt_ids, t5_attn_mask]
 
+    def encode_vision(self, infos, grid_size, t5_out, txt_ids):
+        global siglip_model
+        global siglip_processor
+        global redux_encoder
+
+        if siglip_model is None:
+            model_id = "google/siglip-so400m-patch14-384"
+            siglip_model = SiglipVisionModel.from_pretrained(
+                model_id, attn_implementation="sdpa", device_map="cuda")
+            siglip_processor = AutoProcessor.from_pretrained(model_id)
+
+        if redux_encoder is None:
+            if self.redux_path is None:
+                raise Exception("Vision encoding requires a Redux model, but no file was provided.")
+            model_data = safetensors.torch.load_file(self.redux_path, device=torch.device("cpu").type)
+            redux_encoder = ReduxImageEncoder()
+            redux_encoder.load_state_dict(model_data)
+            redux_encoder = redux_encoder.to(device="cuda")
+
+        bsz = txt_ids.shape[0]
+        imgs = [PIL.Image.open(nfo.absolute_path) for nfo in infos]
+        siglip_in = siglip_processor(images=imgs, padding="max_length", return_tensors="pt")
+        siglip_in = siglip_in.to(device="cuda")
+
+        with torch.no_grad(), torch.autocast("cuda"):
+            siglip_out = siglip_model(**siglip_in)
+            new_embed = redux_encoder(siglip_out.last_hidden_state).float()
+        (b, t, h) = new_embed.shape
+        s = int(sqrt(t))
+        new_embed = torch.nn.functional.interpolate(new_embed.view(b, s, s, h).transpose(1, -1),
+                                                    size=(grid_size, grid_size),
+                                                    mode="bicubic")
+        new_embed = new_embed.transpose(1, -1).reshape(b, -1, h).cpu().numpy()
+        new_ids = np.zeros(shape=(bsz, new_embed.shape[1], txt_ids.shape[2]))
+        attn_mask = np.ones((bsz, new_embed.shape[1]))
+
+        for i, info in enumerate(infos):
+            new_embed_i = new_embed[i]
+            new_ids_i = new_ids[i]
+            attn_mask_i = attn_mask[i]
+            info.vision_encoder_outputs = (new_embed_i, new_ids_i, attn_mask_i)
+
+
     def cache_batch_outputs(
         self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
     ):
@@ -173,6 +254,10 @@ def cache_batch_outputs(
             txt_ids = txt_ids.cpu().numpy()
             t5_attn_mask = tokens_and_masks[2].cpu().numpy()
 
+            if self.vision_cond_size > 0:
+                assert self.vision_cond_size <= 27, "Downsample grid size must not be greater than 27."
+                self.encode_vision(infos, self.vision_cond_size, t5_out, txt_ids)
+
             for i, info in enumerate(infos):
                 l_pooled_i = l_pooled[i]
                 t5_out_i = t5_out[i]
diff --git a/library/train_util.py b/library/train_util.py
index a35388fee..30151dafb 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -176,6 +176,8 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool,
         self.alpha_mask: Optional[torch.Tensor] = None  # alpha mask can be flipped in runtime
 
+        self.vision_encoder_outputs: Optional[List[torch.Tensor]] = None
+
 
 class BucketManager:
     def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
@@ -1497,6 +1499,7 @@ def __getitem__(self, index):
         target_sizes_hw = []
         flippeds = []  # 変数名が微妙
         text_encoder_outputs_list = []
+        vision_encoder_outputs_list = []
         custom_attributes = []
 
         for image_key in bucket[image_index : image_index + bucket_batch_size]:
@@ -1621,6 +1624,9 @@ def __getitem__(self, index):
             text_encoder_outputs = None
             input_ids = None
 
+            if image_info.vision_encoder_outputs is not None:
+                vision_encoder_outputs_list.append(image_info.vision_encoder_outputs)
+
             if image_info.text_encoder_outputs is not None:  # cached
                 text_encoder_outputs = image_info.text_encoder_outputs
@@ -1676,6 +1682,7 @@ def none_or_stack_elements(tensors_list, converter):
         example["custom_attributes"] = custom_attributes  # may be list of empty dict
         example["loss_weights"] = torch.FloatTensor(loss_weights)
         example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor)
+        example["vision_encoder_outputs_list"] = none_or_stack_elements(vision_encoder_outputs_list, torch.FloatTensor)
         example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x)
 
         # if one of alpha_masks is not None, we need to replace None with ones
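
For reference, a minimal standalone sketch (not part of the patch) of what --vision_cond_downsample does inside encode_vision: the Redux-projected SigLIP tokens are treated as a square feature grid, bicubically resized to grid_size x grid_size, and flattened back into a shorter sequence that is later concatenated to the T5-XXL conditioning. The helper name downsample_redux_tokens and the example shapes (a 27x27 grid of 4096-dim tokens) are illustrative assumptions, not values taken from the diff.

import torch
from math import sqrt


def downsample_redux_tokens(redux_embeds: torch.Tensor, grid_size: int) -> torch.Tensor:
    # redux_embeds: (batch, tokens, hidden); tokens must form a square grid, e.g. 729 = 27 * 27.
    b, t, h = redux_embeds.shape
    s = int(sqrt(t))  # side length of the square token grid
    # Treat the token sequence as an s x s "image" with h channels, bicubically resize it
    # down to grid_size x grid_size, then flatten back into a token sequence.
    grid = redux_embeds.view(b, s, s, h).permute(0, 3, 1, 2)
    grid = torch.nn.functional.interpolate(grid, size=(grid_size, grid_size), mode="bicubic")
    return grid.permute(0, 2, 3, 1).reshape(b, grid_size * grid_size, h)


# Example: 729 Redux tokens reduced to a 9 x 9 grid, i.e. 81 conditioning tokens.
tokens = torch.randn(1, 729, 4096)
print(downsample_redux_tokens(tokens, 9).shape)  # torch.Size([1, 81, 4096])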