-
Notifications
You must be signed in to change notification settings - Fork 223
/
Copy pathpredict.py
60 lines (52 loc) · 2.4 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# Prediction interface for Cog ⚙️
# https://cog.run/python
import os
import subprocess
import sys
import time

from cog import BasePredictor, Input, Path
# Local directory holding the model weights; Predictor.setup() only downloads
# when this path does not already exist.
MODEL_CACHE = "checkpoints"
# Weights archive (a .tar) fetched by download_weights() into MODEL_CACHE.
MODEL_URL = "https://weights.replicate.delivery/default/chunyu-li/LatentSync/model.tar"
def download_weights(url, dest):
    """Fetch ``url`` into ``dest`` with the ``pget`` CLI and log how long it took.

    The ``-xf`` flags are expected to make pget extract the archive into
    ``dest`` (confirm against the pget documentation).

    Args:
        url: Remote location of the weights archive.
        dest: Local destination path passed straight to pget.

    Raises:
        subprocess.CalledProcessError: if pget exits with a non-zero status.
        FileNotFoundError: if the ``pget`` executable is not on PATH.
    """
    began = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    pget_cmd = ["pget", "-xf", url, dest]
    subprocess.check_call(pget_cmd, close_fds=False)
    elapsed = time.time() - began
    print("downloading took: ", elapsed)
class Predictor(BasePredictor):
    """Cog predictor that lip-syncs an input video to an input audio track."""

    # Auxiliary checkpoints shipped under MODEL_CACHE/auxiliary that are
    # symlinked into the torch hub cache during setup().
    _AUX_CHECKPOINTS = (
        "2DFAN4-cd938726ad.zip",
        "s3fd-619a316812.pth",
        "vgg16-397923af.pth",
    )

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient.

        Downloads the weights archive into ``MODEL_CACHE`` on first run, then
        symlinks the auxiliary checkpoints into ``~/.cache/torch/hub/checkpoints``
        (presumably so torch.hub-based loaders find them locally instead of
        downloading them — confirm against the inference code).
        """
        if not os.path.exists(MODEL_CACHE):
            download_weights(MODEL_URL, MODEL_CACHE)

        hub_dir = os.path.expanduser("~/.cache/torch/hub/checkpoints")
        os.makedirs(hub_dir, exist_ok=True)
        aux_dir = os.path.abspath(os.path.join(MODEL_CACHE, "auxiliary"))
        for name in self._AUX_CHECKPOINTS:
            link = os.path.join(hub_dir, name)
            # Skip existing entries: `ln -s` previously failed (silently) on a
            # re-run of setup() when the link was already in place.
            if not os.path.lexists(link):
                os.symlink(os.path.join(aux_dir, name), link)

    def predict(
        self,
        video: Path = Input(description="Input video", default=None),
        audio: Path = Input(description="Input audio", default=None),
        guidance_scale: float = Input(
            description="Guidance scale", ge=0, le=10, default=1.0
        ),
        seed: int = Input(description="Set to 0 for Random seed", default=0),
    ) -> Path:
        """Run a single prediction on the model.

        Args:
            video: Source video to re-sync.
            audio: Audio track to sync the video to.
            guidance_scale: Forwarded to the inference script as ``--guidance_scale``.
            seed: Seed for the inference script; values <= 0 pick a random seed.

        Returns:
            Path to the generated video (``/tmp/video_out.mp4``).

        Raises:
            ValueError: if ``video`` or ``audio`` is missing.
            subprocess.CalledProcessError: if the inference script fails.
            RuntimeError: if the script succeeds but writes no output file.
        """
        # Without this check, str(None) == "None" was passed to the script as a path.
        if video is None or audio is None:
            raise ValueError("Both `video` and `audio` inputs are required")

        if seed <= 0:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        config_path = "configs/unet/second_stage.yaml"
        ckpt_path = os.path.join(MODEL_CACHE, "latentsync_unet.pt")
        output_path = "/tmp/video_out.mp4"

        # Remove any result from a previous prediction so a failed run can
        # never hand back a stale video.
        if os.path.exists(output_path):
            os.remove(output_path)

        # Argument list with no shell: user paths are passed verbatim (no
        # quoting/injection issues) and a non-zero exit raises instead of
        # being silently ignored as os.system did.
        subprocess.check_call(
            [
                sys.executable, "-m", "scripts.inference",
                "--unet_config_path", config_path,
                "--inference_ckpt_path", ckpt_path,
                "--guidance_scale", str(guidance_scale),
                "--video_path", str(video),
                "--audio_path", str(audio),
                "--video_out_path", output_path,
                "--seed", str(seed),
            ]
        )

        if not os.path.exists(output_path):
            raise RuntimeError(f"Inference finished but produced no output at {output_path}")
        return Path(output_path)