
How to convert a trained mobilenet_v3_small model to ONNX? #68

Open
C-hongfei opened this issue Aug 24, 2023 · 5 comments

Comments

@C-hongfei

No description provided.

C-hongfei reopened this Aug 24, 2023
@Fafa-DL
Owner

Fafa-DL commented Aug 24, 2023

This hasn't been tested yet; integrating it is planned as the next step.

@rememberBr

I am also looking forward to ONNX deployment support.

@FYT-Dworry

> This hasn't been tested yet; integrating it is planned as the next step.

Hi, has this been done yet? Many thanks!

@the-cat-crying

Same question here.

@PurpleSky-NS
Contributor

PurpleSky-NS commented Dec 4, 2024

You can try this script: save it as tools/export_onnx.py and run it.
Basic usage:
python tools/export_onnx.py <config> --checkpoint <checkpoint>
The default export size is 224x224; you can change it:
python tools/export_onnx.py <config> --checkpoint <checkpoint> --shape 256 256
You can also change the ONNX opset version:
python tools/export_onnx.py <config> --checkpoint <checkpoint> --opset 16
The ONNX file is written to the same directory as the checkpoint, with the checkpoint's extension changed to .onnx.
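The command-line behaviour described above can be checked with the standard library alone. The sketch below is a hypothetical mirror of the script's `--shape` and `--opset` options: one `--shape` value means a square input, two values mean height and width, and the output path is derived by swapping only the file extension. `build_parser`, `normalize_shape` and `onnx_path_for` are illustrative names, not part of the repository.

```python
import argparse
import os


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical mirror of the export script's --shape and --opset options
    parser = argparse.ArgumentParser(description="Export ONNX Model")
    parser.add_argument("--shape", type=int, nargs="+", default=[224, 224])
    parser.add_argument("--opset", type=int, default=12)
    return parser


def normalize_shape(shape):
    # One value -> square (S, S); two values -> (H, W); anything else is rejected
    if len(shape) == 1:
        return (shape[0], shape[0])
    if len(shape) == 2:
        return (shape[0], shape[1])
    raise ValueError("invalid input shape")


def onnx_path_for(checkpoint: str) -> str:
    # Swap only the final extension, e.g. logs/best.pth -> logs/best.onnx
    return os.path.splitext(checkpoint)[0] + ".onnx"


if __name__ == "__main__":
    args = build_parser().parse_args(["--shape", "256", "--opset", "16"])
    print(normalize_shape(args.shape), args.opset, onnx_path_for("path/to/checkpoint.pth"))
```

Running it prints `(256, 256) 16 path/to/checkpoint.onnx`. argparse only accepts unambiguous prefixes of declared options, so a misspelling such as `--optset` is rejected with an "unrecognized arguments" error rather than being silently ignored.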

Install the required libraries first:
pip install onnx onnxsim

Export script:

import argparse
import os
import sys
# Make the repository root importable (run this script from the repo root)
sys.path.insert(0, os.getcwd())

import torch
import onnx
import onnxsim

from models.build import BuildNet
from utils.train_utils import file2dict
from utils.checkpoint import load_checkpoint


def parse_args():
    parser = argparse.ArgumentParser(description='Export ONNX Model')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--checkpoint', required=True, help='the checkpoint file')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[224, 224],
        help='input image size (H W, or a single value for a square input)')
    parser.add_argument('--opset', type=int, default=12, help='onnx opset version')
    args = parser.parse_args()
    return args

def main():
    # Parse the command-line arguments, then read the key fields from the config file
    args = parse_args()
    
    if len(args.shape) == 1:
        input_shape = (args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = args.shape
    else:
        raise ValueError('invalid input shape')
    
    model_cfg, train_pipeline, val_pipeline, data_cfg, lr_config, optimizer_cfg = file2dict(args.config)
    
    print('Initialize the weights.')
    model = BuildNet(model_cfg)
    load_checkpoint(model, args.checkpoint, map_location="cpu", strict=True)
    model.eval()  # inference mode: BatchNorm uses running stats, Dropout is disabled

    print('Exporting onnx model.')
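    # Write the .onnx file next to the checkpoint; splitext avoids overwriting the
    # checkpoint itself when its name does not end in '.pth'
    onnx_path = os.path.splitext(args.checkpoint)[0] + '.onnx'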
    torch.onnx.export(
        model, 
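        # Dummy input plus the two extra positional flags this script passes to
        # BuildNet.forward (both False, presumably selecting the inference path)
        (torch.randn(1, 3, *input_shape), False, False),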
        onnx_path,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
        verbose=False,
        opset_version=args.opset,
        do_constant_folding=True
    )
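    # Simplify the exported graph; `ok` reports whether onnxsim's output check passed
    onnx_model, ok = onnxsim.simplify(onnx.load(onnx_path))
    if not ok:
        raise RuntimeError("ONNX simplification failed.")
    onnx.save(onnx_model, onnx_path)
    print(f'Saved {onnx_path}')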

if __name__ == "__main__":
    main()
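Before deploying, it can help to confirm that the exported graph reproduces the PyTorch outputs for the same input. A common pattern (an assumption here, not something the script above does) is to run the .onnx file with ONNX Runtime and compare the logits element-wise using the tolerance formula of `numpy.allclose`, `|a - b| <= atol + rtol * |b|`. The comparison itself is sketched below in plain Python; `outputs_match` is a hypothetical helper and the tolerances are illustrative, not tuned values.

```python
from typing import Sequence


def outputs_match(onnx_out: Sequence[float], torch_out: Sequence[float],
                  rtol: float = 1e-3, atol: float = 1e-5) -> bool:
    # Element-wise tolerance check (same formula as numpy.allclose):
    # |a - b| <= atol + rtol * |b|, with b the reference PyTorch value.
    # NaNs never compare as close, so a NaN anywhere makes this return False.
    if len(onnx_out) != len(torch_out):
        return False
    return all(abs(a - b) <= atol + rtol * abs(b) for a, b in zip(onnx_out, torch_out))


# How the flattened outputs might be produced (assumes numpy, torch and onnxruntime
# are installed, and that `model` is the PyTorch model loaded as in the export script
# and returns a tensor; "input" is the input name used at export time):
#   import numpy as np, onnxruntime as ort
#   x = np.random.randn(1, 3, 224, 224).astype(np.float32)
#   sess = ort.InferenceSession("path/to/checkpoint.onnx", providers=["CPUExecutionProvider"])
#   onnx_out = sess.run(None, {"input": x})[0].ravel().tolist()
#   torch_out = model(torch.from_numpy(x), False, False).detach().numpy().ravel().tolist()
#   assert outputs_match(onnx_out, torch_out)
```

Comparing against the PyTorch output (rather than relying only on onnxsim's check) catches problems introduced by the export step itself, since onnxsim only compares the simplified graph with the unsimplified ONNX graph.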
