Experimental Redux conditioning for Flux Lora training #1838

Draft · wants to merge 2 commits into base: sd3
13 changes: 13 additions & 0 deletions flux_train_network.py
@@ -8,6 +8,7 @@
from accelerate import Accelerator

from library.device_utils import clean_memory_on_device, init_ipex
from library.strategy_flux import move_vision_encoder_to_device

init_ipex()

@@ -190,6 +191,8 @@ def get_text_encoder_outputs_caching_strategy(self, args):
                args.skip_cache_check,
                is_partial=self.train_clip_l or self.train_t5xxl,
                apply_t5_attn_mask=args.apply_t5_attn_mask,
                vision_cond_size=args.vision_cond_downsample,
                redux_path=args.redux_model_path,
            )
        else:
            return None
@@ -250,6 +253,7 @@ def cache_text_encoder_outputs_if_needed(
            text_encoders[0].to("cpu")
            logger.info("move t5XXL back to cpu")
            text_encoders[1].to("cpu")
            move_vision_encoder_to_device("cpu")
            clean_memory_on_device(accelerator.device)

        if not args.lowram:
@@ -372,6 +376,15 @@ def get_noise_pred_and_target(
        if not args.apply_t5_attn_mask:
            t5_attn_mask = None

        if args.vision_cond_dropout < 1.0:
            # append the cached Redux tokens with probability 1 - vision_cond_dropout
            if random.uniform(0, 1) > args.vision_cond_dropout:
                vision_encoder_conds = batch.get("vision_encoder_outputs_list", None)
                vis_t5_out, vis_txt_ids, vis_attn_mask = vision_encoder_conds
                # extend the text stream so the DiT attends over caption + vision tokens
                t5_out = torch.cat([t5_out, vis_t5_out], dim=1)
                txt_ids = torch.cat([txt_ids, vis_txt_ids], dim=1)
                if args.apply_t5_attn_mask:
                    t5_attn_mask = torch.cat([t5_attn_mask, vis_attn_mask], dim=1)

        def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
            # if not args.split_mode:
            # normal forward
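The hunk above is the training-side hook: with probability 1 - vision_cond_dropout, the cached Redux tokens are appended along the sequence dimension, so the DiT simply attends over a longer text stream. A minimal shape sketch of that concatenation, assuming a 512-token T5 caption and a 9x9 Redux grid (all names and sizes here are illustrative, not taken from the PR):

import torch

bsz, t5_len, vis_len, hidden = 2, 512, 81, 4096
t5_out = torch.randn(bsz, t5_len, hidden)       # T5-XXL caption embeddings
vis_t5_out = torch.randn(bsz, vis_len, hidden)  # projected Redux tokens
txt_ids = torch.zeros(bsz, t5_len, 3)           # Flux positional ids for text
vis_txt_ids = torch.zeros(bsz, vis_len, 3)      # zeroed ids for the vision tokens

# concatenate along the sequence dimension, exactly like the torch.cat calls above
t5_out = torch.cat([t5_out, vis_t5_out], dim=1)     # (2, 593, 4096)
txt_ids = torch.cat([txt_ids, vis_txt_ids], dim=1)  # (2, 593, 3)
assert t5_out.shape[1] == t5_len + vis_len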
19 changes: 19 additions & 0 deletions library/flux_train_utils.py
@@ -617,3 +617,22 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
        default=3.0,
        help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
    )

    parser.add_argument(
        "--redux_model_path",
        type=str,
        help="path to the Redux model (*.sft or *.safetensors); weights should be float16",
    )
    parser.add_argument(
        "--vision_cond_downsample",
        type=int,
        default=0,
        help="downsample the Redux tokens to the given grid size; the native SigLIP grid is 27x27, and values above 27 are rejected. 0 (the default) disables vision conditioning",
    )

    parser.add_argument(
        "--vision_cond_dropout",
        type=float,
        default=1.0,
        help="probability of dropping the Redux conditioning at each step; the default of 1.0 always drops it, disabling the conditioning",
    )
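For reference, a hypothetical invocation combining the three new flags; the checkpoint name and values below are placeholders, not defaults:

# hypothetical extra flags for flux_train_network.py (example values only)
extra_args = [
    "--redux_model_path", "flux1-redux-dev.safetensors",  # float16 Redux weights
    "--vision_cond_downsample", "9",                      # 9x9 grid -> 81 extra tokens
    "--vision_cond_dropout", "0.5",                       # keep Redux tokens on ~50% of steps
]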
91 changes: 88 additions & 3 deletions library/strategy_flux.py
@@ -1,9 +1,12 @@
import os
import glob
from math import sqrt
from typing import Any, List, Optional, Tuple, Union

import safetensors
import torch
import numpy as np
import PIL.Image
from transformers import CLIPTokenizer, T5TokenizerFast, SiglipVisionModel, AutoProcessor

from library import flux_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
@@ -20,6 +23,38 @@
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"


# FIXME: this is a very hacky way of handling the encoder model
siglip_model = None
siglip_processor = None
redux_encoder = None

def move_vision_encoder_to_device(device):
    if siglip_model is not None:
        siglip_model.to(device)
    if redux_encoder is not None:
        redux_encoder.to(device)


class ReduxImageEncoder(torch.nn.Module):
    # Projects SigLIP vision tokens (redux_dim) into the T5-XXL text embedding
    # space (txt_in_features) via an expand-SiLU-contract MLP.
    def __init__(
        self,
        redux_dim: int = 1152,
        txt_in_features: int = 4096,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        self.redux_dim = redux_dim
        self.device = device
        self.dtype = dtype
        self.redux_up = torch.nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
        self.redux_down = torch.nn.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)

    def forward(self, sigclip_embeds) -> torch.Tensor:
        projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
        return projected_x

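A quick shape check for the projector above, assuming the google/siglip-so400m-patch14-384 backbone it is paired with below (729 patch tokens, i.e. a 27x27 grid, 1152-d each):

# sanity-check sketch: SigLIP token width (1152) -> T5-XXL text width (4096)
import torch

enc = ReduxImageEncoder()                  # redux_dim=1152, txt_in_features=4096
siglip_tokens = torch.randn(1, 729, 1152)  # (batch, tokens, hidden)
projected = enc(siglip_tokens)
assert projected.shape == (1, 729, 4096)   # same width as t5_out, ready to concat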

class FluxTokenizeStrategy(TokenizeStrategy):
    def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
        self.t5xxl_max_length = t5xxl_max_length
@@ -95,10 +130,13 @@ def __init__(
        skip_disk_cache_validity_check: bool,
        is_partial: bool = False,
        apply_t5_attn_mask: bool = False,
        vision_cond_size: int = 0,
        redux_path: Optional[str] = None,
    ) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
        self.apply_t5_attn_mask = apply_t5_attn_mask

        self.vision_cond_size = vision_cond_size
        self.redux_path = redux_path
        self.warn_fp8_weights = False

def get_outputs_npz_path(self, image_abs_path: str) -> str:
Expand Down Expand Up @@ -142,6 +180,49 @@ def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
# apply_t5_attn_mask should be same as self.apply_t5_attn_mask
return [l_pooled, t5_out, txt_ids, t5_attn_mask]

    def encode_vision(self, infos, grid_size, t5_out, txt_ids):
        global siglip_model
        global siglip_processor
        global redux_encoder

        # lazily load the SigLIP vision tower on first use
        if siglip_model is None:
            model_id = "google/siglip-so400m-patch14-384"
            siglip_model = SiglipVisionModel.from_pretrained(
                model_id, attn_implementation="sdpa", device_map="cuda")
            siglip_processor = AutoProcessor.from_pretrained(model_id)

        # lazily load the Redux projector on first use
        if redux_encoder is None:
            if self.redux_path is None:
                raise ValueError("Vision encoding requires a Redux model, but no file was provided.")
            model_data = safetensors.torch.load_file(self.redux_path, device="cpu")
            redux_encoder = ReduxImageEncoder()
            redux_encoder.load_state_dict(model_data)
            redux_encoder = redux_encoder.to(device="cuda")

        bsz = txt_ids.shape[0]
        # force RGB so RGBA/grayscale inputs do not break the SigLIP processor
        imgs = [PIL.Image.open(nfo.absolute_path).convert("RGB") for nfo in infos]
        siglip_in = siglip_processor(images=imgs, padding="max_length", return_tensors="pt")
        siglip_in = siglip_in.to(device="cuda")

        with torch.no_grad(), torch.autocast("cuda"):
            siglip_out = siglip_model(**siglip_in)
            new_embed = redux_encoder(siglip_out.last_hidden_state).float()
            # treat the token sequence as a square grid, resize it, then flatten back
            (b, t, h) = new_embed.shape
            s = int(sqrt(t))
            new_embed = torch.nn.functional.interpolate(new_embed.view(b, s, s, h).transpose(1, -1),
                                                        size=(grid_size, grid_size),
                                                        mode="bicubic")
            new_embed = new_embed.transpose(1, -1).reshape(b, -1, h).cpu().numpy()
        new_ids = np.zeros(shape=(bsz, new_embed.shape[1], txt_ids.shape[2]))
        attn_mask = np.ones((bsz, new_embed.shape[1]))

        for i, info in enumerate(infos):
            new_embed_i = new_embed[i]
            new_ids_i = new_ids[i]
            attn_mask_i = attn_mask[i]
            info.vision_encoder_outputs = (new_embed_i, new_ids_i, attn_mask_i)

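Isolated from the method above, the grid resize works as follows; a minimal sketch with assumed sizes (729 SigLIP tokens form the native 27x27 grid):

import torch

b, t, h = 2, 729, 4096                    # assumed: batch, SigLIP tokens, projected width
grid_size = 9
s = int(t ** 0.5)                         # 27
x = torch.randn(b, t, h)
x = x.view(b, s, s, h).transpose(1, -1)   # (b, h, 27, 27), channels-first for interpolate
x = torch.nn.functional.interpolate(x, size=(grid_size, grid_size), mode="bicubic")
x = x.transpose(1, -1).reshape(b, -1, h)  # back to a token sequence: (b, 81, 4096)
assert x.shape == (b, grid_size * grid_size, h)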

def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
@@ -173,6 +254,10 @@
            txt_ids = txt_ids.cpu().numpy()
            t5_attn_mask = tokens_and_masks[2].cpu().numpy()

            if self.vision_cond_size > 0:
                assert self.vision_cond_size <= 27, "Downsample grid size must not be greater than 27 (the native SigLIP grid)."
                self.encode_vision(infos, self.vision_cond_size, t5_out, txt_ids)

            for i, info in enumerate(infos):
                l_pooled_i = l_pooled[i]
                t5_out_i = t5_out[i]
7 changes: 7 additions & 0 deletions library/train_util.py
@@ -176,6 +176,8 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool,

        self.alpha_mask: Optional[torch.Tensor] = None  # alpha mask can be flipped at runtime

        self.vision_encoder_outputs: Optional[List[torch.Tensor]] = None


class BucketManager:
    def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
@@ -1497,6 +1499,7 @@ def __getitem__(self, index):
        target_sizes_hw = []
        flippeds = []  # the variable name is a bit awkward
        text_encoder_outputs_list = []
        vision_encoder_outputs_list = []
        custom_attributes = []

        for image_key in bucket[image_index : image_index + bucket_batch_size]:
@@ -1621,6 +1624,9 @@
            text_encoder_outputs = None
            input_ids = None

            if image_info.vision_encoder_outputs is not None:
                vision_encoder_outputs_list.append(image_info.vision_encoder_outputs)

            if image_info.text_encoder_outputs is not None:
                # cached
                text_encoder_outputs = image_info.text_encoder_outputs
@@ -1676,6 +1682,7 @@ def none_or_stack_elements(tensors_list, converter):
        example["custom_attributes"] = custom_attributes  # may be a list of empty dicts
        example["loss_weights"] = torch.FloatTensor(loss_weights)
        example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor)
        example["vision_encoder_outputs_list"] = none_or_stack_elements(vision_encoder_outputs_list, torch.FloatTensor)
        example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x)

        # if one of the alpha_masks is not None, we need to replace None with ones
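Assuming none_or_stack_elements stacks the cached per-image tuples element-wise across the batch (as it already does for text encoder outputs), the collated entry yields the three batched tensors consumed in get_noise_pred_and_target. A sketch with assumed sizes (bsz=4, 9x9 Redux grid):

import numpy as np
import torch

# one (new_embed, new_ids, attn_mask) tuple per image, as stored by encode_vision
per_image = [(np.zeros((81, 4096)), np.zeros((81, 3)), np.ones((81,))) for _ in range(4)]

# element-wise stacking across the batch
batched = [torch.stack([torch.FloatTensor(item[i]) for item in per_image]) for i in range(3)]
vis_t5_out, vis_txt_ids, vis_attn_mask = batched
assert vis_t5_out.shape == (4, 81, 4096)
assert vis_txt_ids.shape == (4, 81, 3) and vis_attn_mask.shape == (4, 81)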