Skip to content

Commit

Permalink
add export model doc (PaddlePaddle#472)
Browse files Browse the repository at this point in the history
  • Loading branch information
LielinJiang authored Nov 11, 2021
1 parent ee9fae9 commit 283c891
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 11 deletions.
36 changes: 36 additions & 0 deletions deploy/export_model.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# PaddleGAN模型导出教程

## 一、模型导出
本章节介绍如何使用`tools/export_model.py`脚本导出模型。

### 1、启动参数说明

| FLAG | 用途 | 默认值 | 备注 |
|:--------------:|:--------------:|:------------:|:-----------------------------------------:|
| -c | 指定配置文件 | None | |
| --load | 指定加载的模型参数路径 | None | |
| -s|--inputs_size | 指定模型输入形状 | None | |
| --output_dir | 模型保存路径 | `./inference_model` | |

### 2、使用示例

使用训练得到的模型进行试用,这里使用CycleGAN模型为例,脚本如下

```bash
# 下载预训练好的CycleGAN_horse2zebra模型
wget https://paddlegan.bj.bcebos.com/models/CycleGAN_horse2zebra.pdparams

# 导出Cylclegan模型
python -u tools/export_model.py -c configs/cyclegan_horse2zebra.yaml --load CycleGAN_horse2zebra.pdparams --inputs_size="-1,3,-1,-1;-1,3,-1,-1"
```

### 3、config配置说明
```python
export_model:
- {name: 'netG_A', inputs_num: 1}
- {name: 'netG_B', inputs_num: 1}
```
以上为```configs/cyclegan_horse2zebra.yaml```中的配置, 由于```CycleGAN_horse2zebra.pdparams```是个字典,需要制定其中用于导出模型的权重键值。```inputs_num```
为该网络的输入个数。

预测模型会导出到`inference_model/`目录下,分别为`cycleganmodel_netG_A.pdiparams`, `cycleganmodel_netG_A.pdiparams.info`, `cycleganmodel_netG_A.pdmodel`, `cycleganmodel_netG_B.pdiparams`, `cycleganmodel_netG_B.pdiparams.info`, `cycleganmodel_netG_B.pdmodel`,。
16 changes: 11 additions & 5 deletions ppgan/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,18 @@ def set_requires_grad(self, nets, requires_grad=False):
def export_model(self, export_model, output_dir=None, inputs_size=[]):
inputs_num = 0
for net in export_model:
input_spec = [paddle.static.InputSpec(
shape=inputs_size[inputs_num + i], dtype="float32") for i in range(net["inputs_num"])]
input_spec = [
paddle.static.InputSpec(shape=inputs_size[inputs_num + i],
dtype="float32")
for i in range(net["inputs_num"])
]
inputs_num = inputs_num + net["inputs_num"]
static_model = paddle.jit.to_static(self.nets[net["name"]],
input_spec=input_spec)
if output_dir is None:
output_dir = 'export_model'
paddle.jit.save(static_model, os.path.join(
output_dir, '{}_{}'.format(self.__class__.__name__.lower(), net["name"])))
output_dir = 'inference_model'
paddle.jit.save(
static_model,
os.path.join(
output_dir, '{}_{}'.format(self.__class__.__name__.lower(),
net["name"])))
13 changes: 7 additions & 6 deletions tools/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,6 @@

def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--export_model",
default=None,
type=str,
help="The path prefix of inference model to be used.", )
parser.add_argument('-c',
'--config-file',
metavar="FILE",
Expand All @@ -50,6 +45,12 @@ def parse_args():
default=None,
required=True,
help="the inputs size")
parser.add_argument(
"--output_dir",
default=None,
type=str,
help="The path prefix of inference model to be used.",
)
args = parser.parse_args()
return args

Expand All @@ -63,7 +64,7 @@ def main(args, cfg):
for net_name, net in model.nets.items():
if net_name in state_dicts:
net.set_state_dict(state_dicts[net_name])
model.export_model(cfg.export_model, args.export_model, inputs_size)
model.export_model(cfg.export_model, args.output_dir, inputs_size)


if __name__ == "__main__":
Expand Down

0 comments on commit 283c891

Please sign in to comment.