forked from PaddlePaddle/PaddleGAN
Commit
* add serving tensorrt
* modify msvsr infer since the second output is the result
* update wav2lip model path
Showing 8 changed files with 357 additions and 7 deletions.
@@ -0,0 +1,61 @@
# TensorRT Inference Deployment Tutorial

TensorRT is NVIDIA's acceleration library for unified model deployment. It runs on hardware such as the V100 and Jetson Xavier and can greatly improve inference speed. For the Paddle TensorRT tutorial, see [Use the Paddle-TensorRT library for inference](https://paddle-inference.readthedocs.io/en/latest/optimize/paddle_trt.html#).
## 1. Install the Paddle Inference library

- Python package: download a TensorRT-enabled wheel from [here](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/Tables.html#whl-release) and install it.

- C++ inference library: download a TensorRT-enabled build from [here](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/05_inference_deployment/inference/build_and_install_lib_cn.html).

- If no prebuilt Python package or C++ library is available, build from source following [compile from source](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/compile/linux-compile.html).

**Note:**
- The TensorRT version on your machine must match the TensorRT version your inference library was built with.
- Deployment in PaddleGAN requires TensorRT version > 7.0.
## 2. Export the model

See the [PaddleGAN model export tutorial](../EXPORT_MODEL.md) for details on exporting models.
## 3. Enable TensorRT acceleration

### 3.1 Configure TensorRT

When building the predictor `config` with the Paddle inference library, simply turn on the TensorRT engine:
```
config->EnableUseGpu(100, 0);  // initialize 100 MB of GPU memory, use GPU id 0
config->GpuDeviceId();         // returns the GPU id currently in use
// Enable TensorRT inference to speed up GPU prediction;
// this requires a TensorRT-enabled inference library
config->EnableTensorRtEngine(1 << 20 /*workspace_size*/,
                             batch_size /*max_batch_size*/,
                             3 /*min_subgraph_size*/,
                             AnalysisConfig::Precision::kFloat32 /*precision*/,
                             false /*use_static*/,
                             false /*use_calib_mode*/);
```
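The C++ snippet above has a direct counterpart in the Python inference API. The following is a minimal configuration sketch, not taken from this commit; the model and params paths are placeholders you must point at your exported model:

```python
import paddle.inference as paddle_infer

# placeholder paths for an exported inference model
config = paddle_infer.Config("model.pdmodel", "model.pdiparams")
config.enable_use_gpu(100, 0)  # 100 MB initial GPU memory, GPU id 0
# requires a TensorRT-enabled build of Paddle Inference
config.enable_tensorrt_engine(
    workspace_size=1 << 20,
    max_batch_size=1,
    min_subgraph_size=3,
    precision_mode=paddle_infer.PrecisionType.Float32,
    use_static=False,
    use_calib_mode=False)
predictor = paddle_infer.create_predictor(config)
```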
### 3.2 TensorRT fixed-shape inference

Taking `msvsr` as an example, run inference with a fixed input shape:
```
python tools/inference.py --model_path=/root/to/model --config-file /root/to/config --run_mode trt_fp32 --min_subgraph_size 20 --model_type msvsr
```
## 4. FAQ

**Q:** An error says there is no `tensorrt_op`<br/>
**A:** Check that you are using a TensorRT-enabled Paddle Python package or inference library.

**Q:** An error says `op out of memory`<br/>
**A:** Check whether the GPU is also in use by someone else, and try a free GPU.

**Q:** An error says `some trt inputs dynamic shape info not set`<br/>
**A:** `TensorRT` splits the network into multiple subgraphs. We only set dynamic shapes for the input data; the inputs of the other subgraphs have no dynamic shapes set. There are two solutions:

- Option 1: increase `min_subgraph_size` to skip optimizing those subgraphs. Following the error message, set `min_subgraph_size` larger than the number of OPs in any subgraph whose inputs have no dynamic shape set.
  `min_subgraph_size` means that when the TensorRT engine is loaded, only consecutive runs of more than `min_subgraph_size` TensorRT-optimizable OPs are optimized.

- Option 2: find those subgraph inputs and set dynamic shapes for them in the same way as above.

**Q:** How do I turn on logging?<br/>
**A:** The inference library logs by default; just comment out `config.disable_glog_info()` to enable logging.

**Q:** With TensorRT enabled, inference reports "Slice on batch axis is not supported in TensorRT"<br/>
**A:** Try dynamic-shape input.
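Dynamic-shape input, as suggested above, is configured through `set_trt_dynamic_shape_info` in the Python inference API. A hedged sketch follows; the tensor name `lqs`, the shapes, and the model paths are illustrative assumptions, not taken from this commit:

```python
import paddle.inference as paddle_infer

config = paddle_infer.Config("model.pdmodel", "model.pdiparams")  # placeholder paths
config.enable_use_gpu(100, 0)
config.enable_tensorrt_engine(workspace_size=1 << 20,
                              max_batch_size=1,
                              min_subgraph_size=3,
                              precision_mode=paddle_infer.PrecisionType.Float32,
                              use_static=False,
                              use_calib_mode=False)
# min / max / optimal shapes for each dynamic input, keyed by tensor name
config.set_trt_dynamic_shape_info(
    {"lqs": [1, 2, 3, 90, 160]},   # minimum shape
    {"lqs": [1, 2, 3, 180, 320]},  # maximum shape
    {"lqs": [1, 2, 3, 180, 320]})  # optimal shape
```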
@@ -0,0 +1,101 @@
# Server-Side Inference Deployment

Models trained with `PaddleGAN` can be deployed on a server with [Serving](https://github.com/PaddlePaddle/Serving).
This tutorial deploys a model trained on the REDS dataset with the `configs/msvsr_reds.yaml` configuration.
The pretrained weights are [PP-MSVSR_reds_x4.pdparams](https://paddlegan.bj.bcebos.com/models/PP-MSVSR_reds_x4.pdparams).
## 1. Install Paddle Serving

Follow the installation guide in [PaddleServing](https://github.com/PaddlePaddle/Serving/tree/v0.6.0) (version >= 0.6.0).
## 2. Export the model

A PaddleGAN training checkpoint contains both the forward network and optimizer state, but deployment only needs the forward parameters. For details see [export model](https://github.com/PaddlePaddle/PaddleGAN/blob/develop/deploy/EXPORT_MODEL.md).

```
python tools/export_model.py -c configs/msvsr_reds.yaml --inputs_size="1,2,3,180,320" --load /path/to/model --export_serving_model True --output_dir /path/to/output
```
The command above generates a `multistagevsrmodel_generator` folder under `/path/to/output`:
```
output
│ ├── multistagevsrmodel_generator
│ │ ├── multistagevsrmodel_generator.pdiparams
│ │ ├── multistagevsrmodel_generator.pdiparams.info
│ │ ├── multistagevsrmodel_generator.pdmodel
│ │ ├── serving_client
│ │ │ ├── serving_client_conf.prototxt
│ │ │ ├── serving_client_conf.stream.prototxt
│ │ ├── serving_server
│ │ │ ├── __model__
│ │ │ ├── __params__
│ │ │ ├── serving_server_conf.prototxt
│ │ │ ├── serving_server_conf.stream.prototxt
│ │ │ ├── ...
```
`serving_client_conf.prototxt` in the `serving_client` folder describes the model's inputs and outputs in detail.
Its content is:
```
feed_var {
  name: "lqs"
  alias_name: "lqs"
  is_lod_tensor: false
  feed_type: 1
  shape: 1
  shape: 2
  shape: 3
  shape: 180
  shape: 320
}
fetch_var {
  name: "stack_18.tmp_0"
  alias_name: "stack_18.tmp_0"
  is_lod_tensor: false
  fetch_type: 1
  shape: 1
  shape: 2
  shape: 3
  shape: 720
  shape: 1280
}
fetch_var {
  name: "stack_19.tmp_0"
  alias_name: "stack_19.tmp_0"
  is_lod_tensor: false
  fetch_type: 1
  shape: 1
  shape: 3
  shape: 720
  shape: 1280
}
```
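The `feed_var` block above implies a single tensor named `lqs` of shape (1, 2, 3, 180, 320), i.e. batch, frames, channels, height, width; `feed_type: 1` corresponds to float32 in Serving's convention. A minimal sketch of building such a feed with NumPy (the values are random placeholders):

```python
import numpy as np

# feed_type 1 corresponds to float32
lqs = np.random.rand(1, 2, 3, 180, 320).astype("float32")

# feed dict keyed by the alias_name from serving_client_conf.prototxt
feed = {"lqs": lqs}

print(feed["lqs"].shape)  # (1, 2, 3, 180, 320)
```

Note that the fetch shapes are 720 x 1280 against a 180 x 320 input, consistent with the x4 upscale in the model name.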
## 3. Start the PaddleServing service

```
cd output_dir/multistagevsrmodel_generator/
# GPU
python -m paddle_serving_server.serve --model serving_server --port 9393 --gpu_ids 0
# CPU
python -m paddle_serving_server.serve --model serving_server --port 9393
```

## 4. Test the deployed service

```
# enter the exported model folder
cd output/msvsr/
```

Set the `prototxt` file path to `serving_client/serving_client_conf.prototxt`.
Set `fetch` to `fetch=["stack_19.tmp_0"]`.

Run the test:
```
# enter the directory
cd output/msvsr/
# test_client.py creates an output folder and writes the result video under it
python ../../deploy/serving/test_client.py input_video frame_num
```
@@ -0,0 +1,78 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

import cv2
import imageio
import numpy as np
from paddle_serving_app.reader import *
from paddle_serving_client import Client
def get_img(pred):
    pred = pred.squeeze()
    pred = np.clip(pred, a_min=0., a_max=1.0)
    pred = pred * 255
    pred = pred.round()
    pred = pred.astype('uint8')
    pred = np.transpose(pred, (1, 2, 0))  # chw -> hwc
    return pred


# BGR -> RGB, resize to 320x180 (w x h), scale to [0, 1], HWC -> CHW
preprocess = Sequential(
    [BGR2RGB(), Resize((320, 180)), Div(255.0), Transpose((2, 0, 1))])
client = Client()
client.load_client_config("serving_client/serving_client_conf.prototxt")
client.connect(['127.0.0.1:9393'])

# usage: python test_client.py <input_video> <frame_num>
frame_num = int(sys.argv[2])

cap = cv2.VideoCapture(sys.argv[1])
fps = cap.get(cv2.CAP_PROP_FPS)
size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
success, frame = cap.read()
read_end = False
res_frames = []
output_dir = "./output"
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
while success:
    frames = []
    # read frame_num frames per request; the serving model expects exactly this many
    for i in range(frame_num):
        if success:
            frames.append(preprocess(frame))
            success, frame = cap.read()
        else:
            read_end = True
    # an incomplete final batch is dropped, since the model input is fixed to frame_num frames
    if read_end: break

    frames = np.stack(frames, axis=0)
    fetch_map = client.predict(
        feed={
            "lqs": frames,
        },
        fetch=["stack_19.tmp_0"],
        batch=False)
    res_frames.extend(
        [fetch_map["stack_19.tmp_0"][0][i] for i in range(frame_num)])

imageio.mimsave("output/output.mp4",
                [get_img(frame) for frame in res_frames],
                fps=fps)
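The postprocessing in `get_img` above can be checked standalone with NumPy alone. This small sketch reimplements the same steps on a synthetic frame (the 3x4x5 shape is an arbitrary illustration, not the model's real output size):

```python
import numpy as np

def get_img(pred):
    # same steps as test_client.py: squeeze, clip to [0, 1],
    # scale to [0, 255], round, cast to uint8, CHW -> HWC
    pred = pred.squeeze()
    pred = np.clip(pred, a_min=0., a_max=1.0)
    pred = (pred * 255).round().astype('uint8')
    return np.transpose(pred, (1, 2, 0))

frame = np.full((1, 3, 4, 5), 0.5, dtype='float32')  # batch of one 3x4x5 CHW frame
img = get_img(frame)
print(img.shape)  # (4, 5, 3)
print(img[0, 0])  # [128 128 128]
```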
@@ -1,2 +1,2 @@
-Metric psnr: 24.3250
-Metric ssim: 0.6497
+Metric psnr: 27.2885
+Metric ssim: 0.7969