How do I convert a trained mobilenet_v3_small model to ONNX? #68
Comments
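This has not been tested yet; integrating it is planned as the next step. |
I am also looking forward to ONNX deployment support. |
Hello, has this been completed? Many thanks! |
Same question here. |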
You can try this script: save it as tools/export_onnx.py and run it. The required libraries (the script imports onnx and onnxsim) must be installed beforehand. Export script:

import argparse
import os
import sys

# Make the repo's models/ and utils/ packages importable from the working directory.
sys.path.insert(0, os.getcwd())

import torch
import onnx
import onnxsim

from models.build import BuildNet
from utils.train_utils import file2dict
from utils.checkpoint import load_checkpoint


def parse_args():
    parser = argparse.ArgumentParser(description='Export ONNX Model')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--checkpoint', required=True, help='the checkpoint file')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[224, 224],
        help='input image size')
    parser.add_argument('--opset', type=int, default=12, help='onnx opset version')
    args = parser.parse_args()
    return args


def main():
    # Read the config file to get the key fields.
    args = parse_args()
    # Accept either one value (square input) or two values.
    if len(args.shape) == 1:
        input_shape = (args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    model_cfg, train_pipeline, val_pipeline, data_cfg, lr_config, optimizer_cfg = file2dict(args.config)

    print('Initialize the weights.')
    model = BuildNet(model_cfg)
    load_checkpoint(model, args.checkpoint, map_location="cpu", strict=True)
    model.eval()  # export in inference mode

    print('Exporting onnx model.')
    # The .onnx file is written next to the checkpoint.
    onnx_path = args.checkpoint.replace('.pth', '.onnx')
    torch.onnx.export(
        model,
        # Dummy image batch plus the two extra positional arguments passed to forward().
        (torch.randn(1, 3, *input_shape), False, False),
        onnx_path,
        input_names=['input'],
        output_names=['output'],
        # Keep the batch dimension dynamic on both input and output.
        dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
        verbose=False,
        opset_version=args.opset,
        do_constant_folding=True
    )

    # Simplify the exported graph and overwrite the file with the result.
    simplified_model, ok = onnxsim.simplify(onnx.load(onnx_path))
    if not ok:
        raise RuntimeError("Onnx simplifying failed.")
    onnx.save(simplified_model, onnx_path)


if __name__ == "__main__":
    main()
|
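One detail of the script above worth a second look is how it turns `--shape` and `--checkpoint` into an input size and an output path. The sketch below isolates those two steps so they can be tried without PyTorch installed. The helper names `normalize_shape` and `derive_onnx_path` are hypothetical and not part of the repository, and using `os.path.splitext` instead of `str.replace` is a swapped-in choice, not what the posted script does.

```python
import os


def normalize_shape(shape):
    """Hypothetical helper: turn the --shape values into a 2-tuple.

    One value means a square input; two values are used as given.
    """
    if len(shape) == 1:
        return (shape[0], shape[0])
    if len(shape) == 2:
        return tuple(shape)
    raise ValueError('invalid input shape')


def derive_onnx_path(checkpoint_path):
    """Hypothetical helper: replace only the file extension with .onnx.

    Assumption: the .onnx file should sit next to the checkpoint, as in
    the posted script, which uses str.replace('.pth', '.onnx') instead.
    """
    root, _ext = os.path.splitext(checkpoint_path)
    return root + '.onnx'


if __name__ == "__main__":
    print(normalize_shape([224]))
    print(derive_onnx_path('logs/best.pth'))
```

With `str.replace`, a checkpoint whose name does not contain `.pth` would keep its name, so the exported model would be written over the checkpoint itself. Swapping only the extension avoids that case, though it is just one way to handle it.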