Skip to content

Commit

Permalink
feat: allow env components
Browse files Browse the repository at this point in the history
  • Loading branch information
R3gm committed Dec 2, 2024
1 parent be22637 commit a9fe2dc
Showing 1 changed file with 24 additions and 14 deletions.
38 changes: 24 additions & 14 deletions stablepy/diffusers_vanilla/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,12 @@ def __init__(
retain_task_model_in_cache=True,
device=None,
controlnet_model="Automatic",
env_components=None,
):
super().__init__()

self.env_components = env_components

self.device = (
torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device is None
Expand Down Expand Up @@ -273,17 +277,15 @@ def switch_pipe_class(
model_components["tokenizer_2"] = self.pipe.tokenizer_2
model_components["transformer"] = self.pipe.transformer

cpu_device = "cpu"

if task_name == "txt2img":
from diffusers import FluxPipeline
self.pipe = FluxPipeline(**model_components).to(cpu_device)
self.pipe = FluxPipeline(**model_components)
elif task_name == "inpaint":
from .extra_pipe.flux.pipeline_flux_inpaint import FluxInpaintPipeline
self.pipe = FluxInpaintPipeline(**model_components).to(cpu_device)
self.pipe = FluxInpaintPipeline(**model_components)
elif task_name == "img2img":
from .extra_pipe.flux.pipeline_flux_img2img import FluxImg2ImgPipeline
self.pipe = FluxImg2ImgPipeline(**model_components).to(cpu_device)
self.pipe = FluxImg2ImgPipeline(**model_components)
else:
from .extra_pipe.flux.pipeline_flux_controlnet import FluxControlNetPipeline
from .extra_pipe.flux.controlnet_flux import FluxControlNetModel
Expand All @@ -297,11 +299,11 @@ def switch_pipe_class(
model_components["controlnet"] = FluxControlNetModel.from_pretrained(
model_id,
torch_dtype=torch.bfloat16
).to(cpu_device)
)

self.pipe = FluxControlNetPipeline(
**model_components
).to(cpu_device)
)
return None

else:
Expand Down Expand Up @@ -542,11 +544,19 @@ def load_pipe(
except Exception as e:
logger.debug(e)

self.pipe = DiffusionPipeline.from_pretrained(
repo_flux_model,
transformer=transformer,
torch_dtype=self.type_model_precision,
)
if self.env_components is not None:
from diffusers import FluxPipeline
logger.debug(f"Env components: {self.env_components.keys()}")
self.pipe = FluxPipeline(
transformer=transformer,
**self.env_components,
)
else:
self.pipe = DiffusionPipeline.from_pretrained(
repo_flux_model,
transformer=transformer,
torch_dtype=self.type_model_precision,
)

if not self.pipe.transformer.config.guidance_embeds:
self.pipe.scheduler.register_to_config(
Expand Down Expand Up @@ -1277,7 +1287,7 @@ def __call__(
t2i_adapter_conditioning_scale: float = 1.0,
t2i_adapter_conditioning_factor: float = 1.0,

upscaler_model_path: Optional[str] = None, # add latent
upscaler_model_path: Optional[str] = None,
upscaler_increases_size: float = 1.5,
esrgan_tile: int = 100,
esrgan_tile_overlap: int = 10,
Expand Down Expand Up @@ -2289,7 +2299,7 @@ def _execution_device(self):
"prompt": None,
"negative_prompt": None,
"num_inference_steps": hires_steps,
"guidance_scale": hires_guidance_scale if hires_guidance_scale > -0.1 else guidance_scale,
"guidance_scale": hires_guidance_scale if hires_guidance_scale >= 0 else guidance_scale,
"clip_skip": None,
"strength": hires_denoising_strength,
}
Expand Down

0 comments on commit a9fe2dc

Please sign in to comment.