stylize_video.py
import numpy as np
import cv2
import torch
from models import loss_models, transformation_models
from utils import preprocess_batch, deprocess_batch
from torchvision.transforms.functional import resize
from argument_parsers import stylize_video_parser

# select the compute device: use cuda or mps if available, otherwise fall back to cpu
device = {torch.has_cuda: "cuda", torch.has_mps: "mps"}.get(True, "cpu")


def stylize_video(video_path, model_path, save_path, frames_per_step, image_size):
    # load the video
    video = cv2.VideoCapture(video_path)

    # get the video's dimensions and frame count
    width_original = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
    height_original = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = int(video.get(cv2.CAP_PROP_FPS))
    frames_to_capture = frame_count

    print(f"source video dimensions: {width_original}x{height_original}")
    print(f"source video frame count: {frame_count}")
    print(f"source video fps: {fps}\n")

    # compute the output dimensions, scaling so the smaller side equals image_size
    width = width_original
    height = height_original
    if image_size:
        min_dim = min(width_original, height_original)
        width = int(width_original / min_dim * image_size)
        height = int(height_original / min_dim * image_size)

    print(f"output video dimensions: {width}x{height}")
    print(f"output video frame count: {frames_to_capture}")
    print(f"output video fps: {fps}\n")

    # set up the transformation model
    transformation_model = transformation_models.TransformationModel().to(device).eval()

    # load the weights of the pretrained model
    checkpoint = torch.load(model_path, map_location=device)
    transformation_model.load_state_dict(checkpoint["model_state_dict"])
    transformation_model.requires_grad_(False)

    # partition the frames into batches of size 64
    frames_batch_size = 64

    # write the stylized frames out as a video
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(
        save_path,
        fourcc,
        float(fps),
        (width, height),
    )

    # process the video one batch of frames at a time
    for i in range(0, frames_to_capture, frames_batch_size):
        # make sure the last batch has the correct size
        batch_size = frames_batch_size
        if i + frames_batch_size > frames_to_capture:
            batch_size = frames_to_capture - i

        # create a batch of empty frames
        frames_batch = np.empty(
            (batch_size, height_original, width_original, 3), dtype=np.uint8
        )

        print(f"stylizing frames <{i}-{i + batch_size}/{frames_to_capture}>")

        # read the frames for this batch
        frame_index = 0
        while video.isOpened() and frame_index < batch_size:
            ret, frame = video.read()
            if ret:
                frames_batch[frame_index] = frame
                frame_index += 1
            else:
                # no more frames to read
                break

        stylized_batch = stylize_frames_batch(
            frames_batch, transformation_model, frames_per_step, image_size
        )
        for styled_frame in stylized_batch:
            out.write(styled_frame)

    video.release()
    out.release()

    # to add the audio back to the video, run this command in the terminal:
    # ffmpeg -i {save_path} -i {video_path} -c copy -map 0:v:0 -map 1:a:0 {save_with_audio_path}
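
# Optional: a minimal sketch of running the same ffmpeg command from Python via
# subprocess instead of the terminal. This helper is not part of the original
# script; the name `mux_audio` is illustrative and ffmpeg must be on the PATH.
def mux_audio(stylized_path, source_path, output_path):
    import subprocess

    # copy the video stream from the stylized file and the audio stream
    # from the source file into a new container, without re-encoding
    subprocess.run(
        [
            "ffmpeg",
            "-i", stylized_path,
            "-i", source_path,
            "-c", "copy",
            "-map", "0:v:0",
            "-map", "1:a:0",
            output_path,
        ],
        check=True,
    )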

def stylize_frames_batch(
    frames, transformation_model, frames_per_step, image_size=None
):
    """
    Stylize a batch of frames.
    """
    # convert the frames to a torch tensor in NCHW order
    frames = torch.from_numpy(frames).permute(0, 3, 1, 2)

    # preprocess the frames into the form the model expects
    mean = loss_models.VGG16Loss.MEAN
    std = loss_models.VGG16Loss.STD
    frames = preprocess_batch(frames, mean, std)
    if image_size:
        frames = resize(frames, image_size)

    mean = mean.to(device)
    std = std.to(device)

    frames_to_capture = frames.shape[0]

    # stylize the frames a few at a time to bound memory usage
    stylized_frames = torch.empty_like(frames)
    for i in range(0, frames_to_capture, frames_per_step):
        # get the next chunk of frames
        section = frames[i : i + frames_per_step].to(device)

        # stylize the chunk
        stylized_section = transformation_model(section)

        # deprocess the chunk
        stylized_section = deprocess_batch(stylized_section, mean, std)

        # the transformation network can change the spatial dimensions slightly,
        # so resize back to the input dimensions
        stylized_section = resize(
            stylized_section, (section.shape[2], section.shape[3])
        )

        # store the chunk
        stylized_frames[i : i + frames_per_step] = stylized_section

        # print progress periodically
        if i % 24 == 0:
            print(f"from batch, stylized frame [{i}/{frames_to_capture}]")

    print("styled frames successfully\n")

    # convert the frames back to uint8 numpy arrays in NHWC order
    stylized_frames = (
        stylized_frames.detach()
        .cpu()
        .permute(0, 2, 3, 1)
        .mul(255)
        .numpy()
        .astype("uint8")
    )

    # the color channels are in BGR order, so flip them to RGB
    stylized_frames = stylized_frames[:, :, :, ::-1]

    return stylized_frames

if __name__ == "__main__":
    args = stylize_video_parser()

    # stylize the video
    stylize_video(
        args.video_path,
        args.model_path,
        args.save_path,
        args.frames_per_step,
        args.max_image_size,
    )
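
# Example invocation (a sketch; the exact flag names are defined in
# argument_parsers.stylize_video_parser and are only assumed here):
#   python stylize_video.py --video-path input.mp4 --model-path models/style.pth \
#       --save-path stylized.mp4 --frames-per-step 8 --max-image-size 512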