forked from google/prompt-to-prompt
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path run_elite.py
31 lines (29 loc) · 901 Bytes
/
run_elite.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch


def make_latents(bs, seed=42, device="cpu", shape=(4, 64, 64)):
    """Return reproducible initial latents of shape ``(bs, *shape)`` on ``device``.

    The noise is drawn on the CPU from the default generator after reseeding it
    with ``torch.manual_seed(seed)``. The values therefore do not depend on
    ``device``; the tensor is only moved there afterwards.

    Note that ``torch.manual_seed`` reseeds the *global* RNG. This is kept on
    purpose so runs match the original script.
    """
    latents = torch.randn((bs, *shape), generator=torch.manual_seed(seed))
    return latents.to(device)


def main(device="cuda:0", use_fp16=True, bs=4):
    """Generate ``bs`` ELITE samples for one reference image and save them as a grid.

    Args:
        device: torch device the pipeline is moved to and the latents placed on.
        use_fp16: request the ``"fp16"`` checkpoint revision (else ``"fp32"``).
        bs: number of images generated; the reference image is repeated ``bs`` times.

    The output is written to ``test_elite_pipeline_<revision>.png`` in the
    current working directory.
    """
    # Project-local imports stay inside main() so this module can be imported
    # (e.g. for make_latents) without the ELITE code on the path.
    from pipeline_elite import EliteGlobalPipeline
    from utils import image_grid

    revision = "fp16" if use_fp16 else "fp32"
    # NOTE(review): whether from_pretrained accepts an "fp32" revision depends on
    # EliteGlobalPipeline's implementation; confirm before using use_fp16=False.
    pipe = EliteGlobalPipeline.from_pretrained(
        mapper_model_path="./checkpoints/global_mapper.pt",
        revision=revision,
    )
    pipe.to(device)

    ref_images = ["./images_elite/1.jpg"] * bs
    # Seeded after the pipeline is loaded, as in the original script. The 4x64x64
    # latent size presumably corresponds to 512x512 output with an 8x VAE
    # downsampling factor; TODO confirm against the pipeline's defaults.
    latents = make_latents(bs, seed=42, device=device)

    syn_images = pipe(
        prompt=["a photo of a *"] * bs,
        placeholder_token="*",
        ref_image=ref_images,
        guidance_scale=5,
        eta=0,
        num_inference_steps=50,
        token_index="0",
        latents=latents,
    ).images

    # image_grid(images, 1, bs): presumably 1 row x bs columns; verify in utils.
    syn_image = image_grid(syn_images, 1, bs)
    syn_image.save(f"test_elite_pipeline_{revision}.png")


if __name__ == "__main__":
    main()