0
点赞
收藏
分享

微信扫一扫

安全防御(防火墙)

先把训练好的模型导出为 TorchScript 格式(可被 libtorch 加载),导出脚本如下:

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import argparse
import time
import sys
import os
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from timm.models import create_model
import time
import model

input_size = 112  # square input resolution (H == W) used for the tracing example
num_class = 3
model_name = 'repvit_m0_2'
model_path = 'checkpoints/repvit_m0_2/2024_07_03_21_04_49'


def main():
    """Load the best checkpoint and export it as a traced TorchScript file.

    Builds ``model_name`` via timm's ``create_model``, loads
    ``checkpoint_best.pth`` from ``model_path`` and writes
    ``checkpoint_best.torchscript`` into the same directory, ready for
    conversion with pnnx.

    Returns:
        0 on success, 1 if tracing or saving the TorchScript module failed.
    """
    infer_device = torch.device('cpu')

    # Named `net` (not `model`) so it does not shadow the project module
    # `model` imported at the top of the file; that import is presumably
    # there so `create_model` can resolve `model_name` — verify.
    net = create_model(
        model_name,
        num_classes=num_class,
        distillation=False,
        pretrained=False,
    )
    print(net)

    # NOTE(review): torch >= 2.6 defaults to torch.load(weights_only=True); if
    # the checkpoint stores non-tensor objects this call may need
    # weights_only=False (only for trusted files) — confirm for your torch.
    checkpoint = torch.load(
        os.path.join(model_path, 'checkpoint_best.pth'),
        map_location=infer_device,
    )
    # Weights are expected under the 'model' key; a bare state dict is also
    # accepted so the script works with plain `torch.save(net.state_dict())`.
    state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
    net.load_state_dict(state_dict)
    net.to(infer_device)
    net.eval()

    # Dummy input whose shape fixes the traced graph: (1, 3, input_size, input_size).
    img = torch.zeros(1, 3, input_size, input_size, device=infer_device)
    export_path = os.path.join(model_path, 'checkpoint_best.torchscript')

    # No gradients are needed for the dry run or for tracing.
    with torch.no_grad():
        y = net(img)  # dry run
        print(y)

        # TorchScript export
        try:
            print('\nStarting to export TorchScript...')
            trace_model = torch.jit.trace(net, img)
            trace_model.save(export_path)
            output = trace_model(img)  # sanity check: traced module runs
            print(output)
        except Exception as e:
            # Report the failure instead of falling through to a misleading
            # "Export complete" message, and signal it via the exit code.
            print(f'TorchScript export failure: {e}', file=sys.stderr)
            return 1

    print(f'\nExport complete: {export_path}')
    return 0


if __name__ == '__main__':
    sys.exit(main())

运行上述脚本后,会在 checkpoint 目录下得到 checkpoint_best.torchscript。然后使用 pnnx 将其转换为 ncnn 模型(inputshape 需与导出时的输入尺寸 1×3×112×112 一致):

./pnnx RepViT/checkpoints/repvit_m0_2/2024_07_03_21_04_49/checkpoint_best.torchscript inputshape=[1,3,112,112]f32

得到checkpoint_best.ncnn.param和checkpoint_best.ncnn.bin,即转换完成

举报

相关推荐

0 条评论