0
点赞
收藏
分享

微信扫一扫

41、使用mmrotate进行旋转目标检测,并进行ncnn和mnn部署

Hyggelook 2022-04-13 阅读 106
python

基本思想:仍然以身份证为例,这次做旋转目标检测,并进行ncnn和mnn的c++部署开发

第一步:下载源码

ubuntu@ubuntu-Super-Server:~/sxj731533730$ git clone https://github.com/open-mmlab/mmrotate.git

Cloning into 'mmrotate'...
remote: Enumerating objects: 23529, done.
remote: Total 23529 (delta 0), reused 0 (delta 0), pack-reused 23529
Receiving objects: 100% (23529/23529), 35.29 MiB | 3.54 MiB/s, done.
Resolving deltas: 100% (16443/16443), done.

ubuntu@ubuntu-Super-Server:~/sxj731533730$ cd mmrotate/
ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate$ pip3 install -r requirements.txt

第二步:下载模型和测试

ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate$ mkdir model
ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate$ cd model/
ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate/model$ cat ../configs/kfiou/README.md
ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate/model$ axel -n 100 https://download.openmmlab.com/mmrotate/v0.1.0/kfiou/r3det_kfiou_ln_r50_fpn_1x_dota_oc/r3det_kfiou_ln_r50_fpn_1x_dota_oc-8e7f049d.pth
ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate/model$ pip install mmcv-full==1.4.6 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.10.0/index.html

然后修改一下测试程序

ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate/demo$ python3 image_demo.py demo.jpg ../configs/kfiou/r3det_kfiou_ln_r50_fpn_1x_dota_oc.py ../model/r3det_kfiou_ln_r50_fpn_1x_dota_oc-8e7f049d.pth sxj731533730.jpg

第三步:开始制作数据集。比如我只标注了四张身份证图片,每个目标标注4个顶点:一张含1个反面、一张含3个反面、一张含1个正面、一张含3个正面(正面的图就不贴了)

标注方式参考官网解释  https://captain-whu.github.io/DOTA/index.html

x1, y1, x2, y2, x3, y3, x4, y4:四边形的四个顶点的坐标 顶点按顺时针顺序排列,第一个起点为左上第一个点

使用的labelme 标注的,也可以使用rolabelimg标注~

使用下面的旋转代码,批量生成不同角度的图片

# -*- coding: utf-8 -*-
import os
import sys
import json
import io
import random
import re
import cv2
import numpy as np
from random import choice
import math

source_path = r'A'
destination_path = r'B'
# Rotation angles in degrees: 0, 15, 30, ..., 345.
angle = list(range(0, 360, 15))


def file_name(file_dir):
    """Return the paths of the ``.json`` files directly inside *file_dir*.

    Only the top level of *file_dir* is scanned; sub-directories are not
    descended into. A missing directory yields an empty list (previously
    the function fell off the loop and returned ``None``).
    """
    L = []
    for root, dirs, files in os.walk(file_dir):
        for file in files:
            if os.path.splitext(file)[1] == '.json':
                L.append(os.path.join(root, file))
        # Only the first (top-level) os.walk entry is used.
        break
    return L


def rotation_point(img, angle, pts):
    """Rotate *img* by *angle* degrees about its centre without cropping.

    The output canvas is enlarged to fit the rotated image, and *pts*
    (anything reshapeable to N x 2 of x, y pixel coordinates) is mapped
    with the same affine matrix.

    Returns:
        (rotated_img, rotated_pts); rotated_pts has shape (N, 1, 2).
    """
    cols = img.shape[1]
    rows = img.shape[0]
    M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)
    sin_a = math.fabs(math.sin(math.radians(angle)))
    cos_a = math.fabs(math.cos(math.radians(angle)))
    heightNew = int(cols * sin_a + rows * cos_a)
    widthNew = int(rows * sin_a + cols * cos_a)
    # Shift the translation so the rotated image is centred on the new canvas.
    M[0, 2] += (widthNew - cols) / 2
    M[1, 2] += (heightNew - rows) / 2
    img = cv2.warpAffine(img, M, (widthNew, heightNew))
    pts = cv2.transform(np.asarray(pts, dtype=np.float64).reshape((-1, 1, 2)), M)
    return img, pts


for angle_item in angle:
    for m_path in file_name(source_path):
        src_dir = os.path.dirname(m_path)
        with io.open(m_path, 'r', encoding='utf-8') as file_json:
            data = json.load(file_json)
        data_name = data['imagePath']
        data_path = os.path.join(src_dir, data_name)
        if not data['shapes']:
            # No shapes means no output name was ever computed; skip instead
            # of crashing on open(None).
            print('no shapes, skipped:', m_path)
            continue
        img = cv2.imread(data_path)
        if img is None:
            print('cannot read image, skipped:', data_path)
            continue

        filename = os.path.splitext(data_name)[0]
        new_image_name = filename + ".".join([str(angle_item), "jpg"])
        data_new_picture_name = os.path.join(destination_path, new_image_name)
        data_new_json_name = os.path.join(destination_path, filename + ".".join([str(angle_item), "json"]))

        shape_json = []
        im_rotate = None
        for shape in data['shapes']:
            m_name_0 = shape['label']
            print('m_name_0=', m_name_0)
            item_point = []
            for point in shape['points']:
                print(point[0], point[1])
                item_point.append([point[0], point[1]])
            # The rotated image is the same for every shape; only the
            # transformed points differ.
            im_rotate, rotated = rotation_point(img, angle_item, np.asarray(item_point))
            # reshape (not squeeze) keeps a list of [x, y] pairs even when a
            # shape has a single point.
            item_point = rotated.reshape(-1, 2).tolist()
            print(item_point)
            shape_json.append({"label": m_name_0,
                               "points": item_point,
                               "shape_type": shape['shape_type']})

        os.makedirs(destination_path, exist_ok=True)
        if not cv2.imwrite(data_new_picture_name, im_rotate):
            print('cannot write image, skipped:', data_new_picture_name)
            continue

        data_json = {
            'version': '5.0.1',
            'flags': {},
            'lineColor': [0, 255, 0, 128],
            'fillColor': [255, 0, 0, 128],
            'imageData': None,
            'imagePath': new_image_name,
            'imageWidth': im_rotate.shape[1],
            'imageHeight': im_rotate.shape[0],
            'shapes': shape_json,
        }
        # Dump the dict directly as UTF-8. The former dumps -> dump ->
        # string-replace round trip broke on labels containing quotes.
        with io.open(data_new_json_name, 'w', encoding='utf-8') as fp:
            json.dump(data_json, fp, ensure_ascii=False)

只贴两张

第四步:训练我们自己的数据集

ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate$ mkdir -p kfiouDataSets

里面是标准的jpg和json(labelme标注的格式)

​ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate/kfiouDataSets$ tree  -L 1

├── trainDataset  #里面含有jpg和对应的json文件

├── valDataset  #里面含有jpg和对应的json文件

└── testDataset  #里面含有jpg和对应的json文件

3 directories, 0 files

生成 DOTA数据集(参考附录作者修改了下面脚本) labelme2dota.py

生成格式:每张图片对应一个txt文件,每个目标占一行,格式说明参见 https://captain-whu.github.io/DOTA/index.html

x1, y1, x2, y2, x3, y3, x4, y4, category, difficult
x1, y1, x2, y2, x3, y3, x4, y4, category, difficult
category: 标签名字
difficult:表示标签检测的难易程度 (1表示困难,0表示不困难)

转换脚本 labelme2dota:我担心顺时针排序的代码写得有问题,所以也做了反向转换(dota→labelme)验证;另外因为 mmrotate 的代码训练时需要 png 图片,我就不改代码了,直接转换图片格式吧

import json
import os
from glob import glob
import argparse
import numpy as np
import shutil
from PIL import Image
import cv2

# convert labelme json to DOTA txt format
# convert DOTA json to lableme txt format
def custombasename(fullname):
    """Return *fullname*'s file name with the directory and last extension removed."""
    stem, _extension = os.path.splitext(fullname)
    return os.path.basename(stem)


def order_points_new(pts):  # clock -https://zhuanlan.zhihu.com/p/10643062
    """Order four 2-D points as top-left, top-right, bottom-right, bottom-left.

    The points are split into the two with the smallest x and the two with
    the largest x. Within each pair the smaller-y point comes first; when
    both points of a pair share the same y, the larger-x point comes first.

    Args:
        pts: numpy array of shape (4, 2) holding (x, y) rows.

    Returns:
        float32 array of shape (4, 2) in [tl, tr, br, bl] order.
    """
    def _upper_first(pair):
        # Tie on y: fall back to descending x; otherwise ascending y.
        if pair[0, 1] == pair[1, 1]:
            return pair[np.argsort(pair[:, 0])[::-1], :]
        return pair[np.argsort(pair[:, 1]), :]

    by_x = pts[np.argsort(pts[:, 0]), :]
    tl, bl = _upper_first(by_x[:2, :])
    tr, br = _upper_first(by_x[2:, :])
    return np.array([tl, tr, br, bl], dtype="float32")


def _str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no', '1'/'0'.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    'false') into True, so ``--verify false`` never disabled verification.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def _find_source_image(input_dir, name):
    """Return ``<input_dir>/<name>.jpg`` or ``.png`` (jpg checked first), or None."""
    for ext in ("jpg", "png"):
        candidate = os.path.join(input_dir, ".".join([name, ext]))
        if os.path.exists(candidate):
            return candidate
    return None


def _save_as_png(source_path, target_path):
    """Store *source_path* at *target_path* as a real PNG file.

    PNG sources are copied byte-for-byte. Anything else (e.g. JPEG) is
    re-encoded with Pillow, so the file content matches the ``.png`` name
    instead of being JPEG bytes under a PNG extension.
    """
    if os.path.splitext(source_path)[1].lower() == ".png":
        shutil.copy(source_path, target_path)
        return
    with Image.open(source_path) as img:
        # PNG cannot store e.g. CMYK; convert such modes to RGB first.
        if img.mode not in ("1", "L", "LA", "I", "P", "RGB", "RGBA"):
            img = img.convert("RGB")
        img.save(target_path, format="PNG")


parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument('--input_dir', default=r'C:\Users\PHILIPS\Desktop\source', type=str,
                    help='input annotated directory')
parser.add_argument('--output_images', default=r'C:\Users\PHILIPS\Desktop\datasets\images', type=str,
                    help='output image directory (images are stored as .png)')
parser.add_argument('--output_dir', default=r'C:\Users\PHILIPS\Desktop\datasets\labelTxt', type=str, help='output dataset directory')
parser.add_argument('--verify_dir', default=r'C:\Users\PHILIPS\Desktop\datasets\verify', type=str,
                    help='verify dataset directory')
parser.add_argument('--verify', default=True, type=_str2bool, help='verify')
# NOTE: --labels is accepted for CLI compatibility but not used below.
parser.add_argument('--labels', default=r'labels.txt', type=str, help='labels annotated directory')
args = parser.parse_args()

os.makedirs(args.output_dir, exist_ok=True)
os.makedirs(args.output_images, exist_ok=True)
print('Creating dataset:', args.output_dir)

file_list = glob(os.path.join(args.input_dir, ".".join(["*", "json"])))

# ---- labelme json -> DOTA txt (+ png image) ----
for i, json_path in enumerate(file_list):
    with open(json_path, encoding='utf-8') as f:
        label_dict = json.load(f)

    base_name = custombasename(json_path)
    out_file = os.path.join(args.output_dir, ".".join([base_name, 'txt']))
    # One line per shape: "x1 y1 x2 y2 x3 y3 x4 y4 <label> 0" (difficulty 0).
    out_lines = []
    for shape_dict in label_dict['shapes']:
        ordered = order_points_new(
            np.array([[p[0], p[1]] for p in shape_dict['points']], dtype="float"))
        coords = ''.join(str(x) + ' ' + str(y) + ' ' for x, y in ordered.tolist())
        out_lines.append(coords + shape_dict['label'] + ' 0\n')
    with open(out_file, 'w', encoding='utf-8') as fout:
        fout.write(''.join(out_lines))

    # Export the image as .png regardless of --verify (training needs png).
    source_image = _find_source_image(args.input_dir, base_name)
    if source_image is None:
        print("check photo type")
    else:
        _save_as_png(source_image, os.path.join(args.output_images, ".".join([base_name, "png"])))
    print('%d/%d' % (i + 1, len(file_list)))
    print("labelme2dota...")

# ---- optional round trip: DOTA txt -> labelme json, for visual checking ----
if args.verify:
    os.makedirs(args.verify_dir, exist_ok=True)
    txt_list = glob(os.path.join(args.output_dir, ".".join(["*", "txt"])))
    for txt_path in txt_list:
        filename = os.path.splitext(os.path.basename(txt_path))[0]
        sourcePath = _find_source_image(args.input_dir, filename)
        if sourcePath is None:
            print("check photo type")
            continue
        image_filename = ".".join([filename, "png"])
        _save_as_png(sourcePath, os.path.join(args.verify_dir, image_filename))
        with Image.open(sourcePath) as img:
            w, h = img.width, img.height

        data = {
            'imagePath': image_filename,
            'flags': {},
            'imageWidth': w,
            'imageHeight': h,
            'imageData': None,
            'version': "5.0.1",
            'shapes': [],
        }

        with open(txt_path, encoding='utf-8') as f:
            label_lines = f.readlines()
        for label_item in label_lines:
            line_char = label_item.split("\n")[0].split(' ')
            if len(line_char) < 10:
                continue  # blank or malformed line
            # float() instead of eval(): the txt content is data, not code.
            coords = [float(v) for v in line_char[:8]]
            points = [coords[k:k + 2] for k in range(0, 8, 2)]
            data["shapes"].append({
                'points': points,
                'flags': {},  # labelme's per-shape key is "flags"
                'group_id': None,
                'shape_type': "polygon",
                'label': line_char[-2],
            })

        jsonName = ".".join([filename, "json"])
        jsonPath = os.path.join(args.verify_dir, jsonName)
        with open(jsonPath, "w", encoding='utf-8') as jf:
            json.dump(data, jf)
        print(jsonName)
        print("dota2labelme...")

执行命令,最好在本地用该脚本,熟悉一下转换方式

ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate/kfiouDataSets$ sudo vim labels.txt
ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate/kfiouDataSets$ cat labels.txt
card

ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate$ python3 labelme2dota.py --input_dir kfiouDataSets/trainDataset/ --output_dir kfiouDataSets/train --verify false --labels kfiouDataSets/labels.txt
ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate$ python3 labelme2dota.py --input_dir kfiouDataSets/valDataset/ --output_dir kfiouDataSets/val --verify false --labels kfiouDataSets/labels.txt
ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate$ python3 labelme2dota.py --input_dir kfiouDataSets/testDataset/ --output_dir kfiouDataSets/test --verify false --labels kfiouDataSets/labels.txt

划分训练集、验证集和测试集

import os
from glob import glob
import shutil
import random


ann_txt = r"C:\Users\PHILIPS\Desktop\dest"
train_annfile = r"C:\Users\PHILIPS\Desktop\train_annfile"
test_annfile = r"C:\Users\PHILIPS\Desktop\test_annfile"
val_annfile = r"C:\Users\PHILIPS\Desktop\val_annfile"

# Fractions of the annotation files for each split. The test split receives
# whatever is left after train and val, so the three should sum to 1. The
# previous 0.8 / 0.2 / 0.2 summed to 1.2.
train_num = 0.8
test_num = 0.1
val_num = 0.1


def split_list(items, train_ratio, val_ratio, rng=random):
    """Shuffle *items* once and cut it into disjoint (train, val, test) lists.

    The first ``int(train_ratio * n)`` shuffled items go to train and the
    next ``int(val_ratio * n)`` go to val. The remainder goes to test. *items*
    itself is not modified.
    """
    shuffled = list(items)
    rng.shuffle(shuffled)
    n_train = int(train_ratio * len(shuffled))
    n_val = int(val_ratio * len(shuffled))
    return (shuffled[:n_train],
            shuffled[n_train:n_train + n_val],
            shuffled[n_train + n_val:])


for out_dir in (train_annfile, test_annfile, val_annfile):
    os.makedirs(out_dir, exist_ok=True)

txt_list = glob(os.path.join(ann_txt, ".".join(["*", "txt"])))
# One shuffle, disjoint slices. The old version drew three independent
# random.sample()s, so the subsets overlapped and test_list was ignored.
train_list, val_list, test_list = split_list(txt_list, train_num, val_num)

for dest_dir, subset in ((train_annfile, train_list),
                         (val_annfile, val_list),
                         (test_annfile, test_list)):
    for item in subset:
        # glob already returned the full path; re-joining it with ann_txt
        # duplicated the directory whenever ann_txt was relative.
        shutil.copy(item, dest_dir)
        print("copy txt into file")
print("complish")

整个目录结构

ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate$ tree -L 1
.
├── CITATION.cff
├── configs
├── demo
├── docker
├── docs
├── generate.py
├── kfiouDataSets
├── LICENSE
├── MANIFEST.in
├── mmrotate
├── model
├── model-index.yml
├── README.md
├── README_zh-CN.md
├── requirements
├── requirements.txt
├── resources
├── setup.cfg
├── setup.py
├── tests
└── tools

11 directories, 10 files

修改配置文件参数并增加参数,数据集目录结构如下

ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate/kfiouDataSets$ tree -L 1
.
├── labels.txt
├── source
├── test_annfile
├── train_annfile
└── val_annfile

4 directories, 1 file

修改配置文件

/home/ubuntu/sxj731533730/mmrotate/configs/_base_/datasets/dotav1.py

变更数据集的根目录和训练集 测试集 验证集的目录

# dataset settings
# Dataset class implemented in mmrotate/datasets/dota.py; it reads the
# DOTA-style txt annotations produced by labelme2dota.py.
dataset_type = 'DOTADataset'
# Root of the custom dataset; the ann_file / img_prefix entries are appended to it.
data_root = '/home/ubuntu/sxj731533730/mmrotate/kfiouDataSets/'
。。。。。
# Each split reads its DOTA txt labels from its own *_annfile/ directory,
# and all splits share the images in source/.
# NOTE(review): the commented-out `classes=('card',)` lines are inactive, so
# class names currently come from DOTADataset.CLASSES (edited in
# mmrotate/datasets/dota.py). Passing `classes` here may avoid editing the
# library — TODO confirm against the installed mmrotate version.
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        #classes=('card',),
        ann_file=data_root + 'train_annfile/',
        img_prefix=data_root + 'source/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        #classes=('card',),
        ann_file=data_root + 'val_annfile/',
        img_prefix=data_root + 'source/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        #classes=('card',),
        ann_file=data_root + 'test_annfile/',
        img_prefix=data_root + 'source/',
        pipeline=test_pipeline))

修改一下标签

ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate$ sudo vim mmrotate/datasets/dota.py

class DOTADataset(CustomDataset):
    """DOTA dataset for detection.

    Args:
        ann_file (str): Annotation file path.
        pipeline (list[dict]): Processing pipeline.
        version (str, optional): Angle representations. Defaults to 'oc'.
        difficulty (bool, optional): The difficulty threshold of GT.
    """
    # Original DOTA v1.0 class list, kept for reference (was a no-op string
    # literal statement; a real comment makes the intent explicit):
    # CLASSES = ('plane', 'baseball-diamond', 'bridge', 'ground-track-field',
    #            'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
    #            'basketball-court', 'storage-tank', 'soccer-ball-field',
    #            'roundabout', 'harbor', 'swimming-pool', 'helicopter')

    # Single custom class. The trailing comma is required: ('card') is just
    # the string 'card', not a one-element tuple.
    CLASSES = ('card',)


修改配置文件

ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate$ sudo vim configs/r3det/r3det_r50_fpn_1x_dota_oc.py
ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate$ sudo vim  configs/kfiou/r3det_kfiou_ln_r50_fpn_1x_dota_oc.py

num_classes=1  # one foreground class ('card'); keep equal to len(DOTADataset.CLASSES)

建立一个存储日志和模型的文件夹

ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate$ mkdir run

第五步:开始训练模型

ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate$ python3 tools/train.py configs/kfiou/r3det_kfiou_ln_r50_fpn_1x_dota_oc.py --work-dir=run

训练的过程

2022-03-22 06:21:11,014 - mmrotate - INFO - workflow: [('train', 1)], max: 12 epochs
2022-03-22 06:21:11,014 - mmrotate - INFO - Checkpoints will be saved to /home/ubuntu/sxj731533730/mmrotate/run by HardDiskBackend.
2022-03-22 06:21:35,543 - mmrotate - INFO - Epoch [1][50/330]   lr: 9.967e-04, eta: 0:31:57, time: 0.491, data_time: 0.050, memory: 3442, s0.loss_cls: 1.1651, s0.loss_bbox: 6.7060, sr0.loss_cls: 1.1758, sr0.loss_bbox: 6.2110, loss: 15.2579, grad_norm: 8.4750
2022-03-22 06:21:57,744 - mmrotate - INFO - Epoch [1][100/330]  lr: 1.163e-03, eta: 0:30:03, time: 0.444, data_time: 0.005, memory: 3442, s0.loss_cls: 0.9773, s0.loss_bbox: 6.1366, sr0.loss_cls: 0.4372, sr0.loss_bbox: 5.9701, loss: 13.5212, grad_norm: 26.8998
2022-03-22 06:22:19,983 - mmrotate - INFO - Epoch [1][150/330]  lr: 1.330e-03, eta: 0:29:11, time: 0.445, data_time: 0.005, memory: 3442, s0.loss_cls: 0.3552, s0.loss_bbox: 5.9312, sr0.loss_cls: 0.2589, sr0.loss_bbox: 5.9402, loss: 12.4856, grad_norm: 18.0854
2022-03-22 06:22:42,092 - mmrotate - INFO - Epoch [1][200/330]  lr: 1.497e-03, eta: 0:28:32, time: 0.442, data_time: 0.005, memory: 3442, s0.loss_cls: 0.2930, s0.loss_bbox: 5.8729, sr0.loss_cls: 0.2286, sr0.loss_bbox: 5.9325, loss: 12.3271, grad_norm: 15.2396
2022-03-22 06:23:04,421 - mmrotate - INFO - Epoch [1][250/330]  lr: 1.663e-03, eta: 0:28:02, time: 0.447, data_time: 0.005, memory: 3442, s0.loss_cls: 0.2843, s0.loss_bbox: 5.8343, sr0.loss_cls: 0.2799, sr0.loss_bbox: 5.8648, loss: 12.2633, grad_norm: 12.3675
2022-03-22 06:23:26,901 - mmrotate - INFO - Epoch [1][300/330]  lr: 1.830e-03, eta: 0:27:37, time: 0.450, data_time: 0.005, memory: 3442, s0.loss_cls: 0.2462, s0.loss_bbox: 5.7725, sr0.loss_cls: 0.0950, sr0.loss_bbox: 5.8509, loss: 11.9646, grad_norm: 9.6189
2022-03-22 06:24:04,773 - mmrotate - INFO - Epoch [2][50/330]   lr: 2.097e-03, eta: 0:25:11, time: 0.491, data_time: 0.049, memory: 3442, s0.loss_cls: 0.1637, s0.loss_bbox: 5.6981, sr0.loss_cls: 0.0575, sr0.loss_bbox: 5.8003, loss: 11.7196, grad_norm: 6.5400
......
2022-03-23 01:48:42,354 - mmrotate - INFO -
+-------+-----+------+--------+-------+
| class | gts | dets | recall | ap    |
+-------+-----+------+--------+-------+
| card  | 50  | 53   | 1.000  | 1.000 |
+-------+-----+------+--------+-------+
| mAP   |     |      |        | 1.000 |
+-------+-----+------+--------+-------+
2022-03-23 01:48:42,385 - mmrotate - INFO - Exp name: r3det_kfiou_ln_r50_fpn_1x_dota_oc.py
2022-03-23 01:48:42,386 - mmrotate - INFO - Epoch(val) [12][28] mAP: 1.0000

第六步:测试一下,我只训练了12个epoch而已

ubuntu@ubuntu-Super-Server:~/sxj731533730/mmrotate$ python3 demo/image_demo.py kfiouDataSets/source/cap_output_00_03_30_30210.jpg configs/kfiou/r3det_kfiou_ln_r50_fpn_1x_dota_oc.py run/epoch_12.pth demo/sxj731533730.jpg

将身份证的检测结果图片写入本地显示(还未训练好,继续训练中)

第七步:将mmrotate模型转为onnx。转换过程中涉及到mmcv的函数(自定义算子),好难啊,官方暂时没提供转换脚本,只能参考mmdetection的脚本~

import mmcv
import numpy as np
import torch
from mmcv.ops import RoIPool
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint

from mmdet.core import get_classes
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
from mmdet.models import build_detector
from argparse import ArgumentParser

from mmdet.apis import inference_detector, init_detector
import torch
# import trochvision
import torch.utils.data
import argparse
import onnxruntime
import os
import cv2
import numpy as np
from onnxruntime.datasets import get_example

import mmrotate  # noqa: F401
# Default inference device: the first CUDA GPU when available, otherwise CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")




def processimage(imgs,model):
    """Run *model*'s test pipeline on *imgs* and return a collated batch dict.

    This is the preprocessing half of mmdet's ``inference_detector``.

    Args:
        imgs: one image (file path or ``np.ndarray``) or a list/tuple of them.
        model: detector whose ``cfg.data.test.pipeline`` is applied.

    Returns:
        dict whose ``img_metas`` and ``img`` entries are unwrapped from their
        DataContainers. The dict is scattered to the model's GPU when the
        model's parameters live on CUDA.
    """
    batch = list(imgs) if isinstance(imgs, (list, tuple)) else [imgs]

    model_device = next(model.parameters()).device
    cfg = model.cfg
    if isinstance(batch[0], np.ndarray):
        # In-memory arrays need the webcam loader instead of a file loader.
        cfg = cfg.copy()
        cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'

    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    test_pipeline = Compose(cfg.data.test.pipeline)

    def _prepare(item):
        # Arrays are passed as-is; paths are wrapped as img_info for the loader.
        if isinstance(item, np.ndarray):
            return test_pipeline(dict(img=item))
        return test_pipeline(dict(img_info=dict(filename=item), img_prefix=None))

    data = collate([_prepare(item) for item in batch], samples_per_gpu=len(batch))
    data['img_metas'] = [meta.data[0] for meta in data['img_metas']]
    data['img'] = [tensor.data[0] for tensor in data['img']]

    if next(model.parameters()).is_cuda:
        return scatter(data, [model_device])[0]

    for module in model.modules():
        assert not isinstance(
            module, RoIPool
        ), 'CPU inference with RoIPool is not supported currently.'
    return data
def torch2onnx(args, model):
    """Export *model* to ONNX at ``args.onnx_model_path`` and smoke-test it.

    The preprocessed image from :func:`processimage` is used as the tracing
    input. ``model.forward`` is replaced by ``model.forward_dummy``, so the
    model should not be used for normal inference afterwards. The exported
    file is then run once with onnxruntime and the first output is printed.

    Args:
        args: parsed CLI namespace providing ``img`` and ``onnx_model_path``.
        model: detector built by ``init_detector``.
    """
    dummy_input = processimage(args.img, model)
    # processimage already placed the input on the model's own device. Do
    # not move the model to the module-level `device`: it can differ from
    # --device and make tracing fail with a device mismatch.
    img_tensor = dummy_input['img'][0]

    model.forward = model.forward_dummy
    input_names = ["input"]  # ONNX graph input name
    output_names = ["output"]  # ONNX graph output name
    print("====", dummy_input, type(dummy_input))

    torch.onnx.export(model, img_tensor, args.onnx_model_path,
                      input_names=input_names, output_names=output_names,
                      export_params=True,
                      keep_initializers_as_inputs=True,
                      do_constant_folding=True,
                      verbose=False,
                      opset_version=11)

    # Test the exported model. Open the file directly:
    # onnxruntime.datasets.get_example() resolves names inside the
    # onnxruntime package directory, which breaks relative paths. Newer GPU
    # builds also require an explicit providers list.
    session = onnxruntime.InferenceSession(
        args.onnx_model_path, providers=onnxruntime.get_available_providers())
    input_name = session.get_inputs()[0].name
    # dummy_input is a dict, so feed the actual image tensor (the old
    # `dummy_input.data` raised AttributeError).
    result = session.run(None, {input_name: img_tensor.detach().cpu().numpy()})
    results = torch.tensor(result[0], dtype=torch.float32)
    print(results)


def main():
    """Run one demo inference, save the visualisation, then export to ONNX."""
    parser = ArgumentParser()
    string_options = (
        ('--img', r"G:\mmrotate\kfiouDataSets\source\cap_output_00_03_30_30210.jpg", 'Image file'),
        ('--config', r"G:/mmrotate/configs/kfiou/r3det_kfiou_ln_r50_fpn_1x_dota_oc.py", 'Config file'),
        ('--checkpoint', r"G:\mmrotate\demo\epoch_12.pth", 'Checkpoint file'),
        ('--onnx_model_path', r"G:\mmrotate\demo\epoch_12.onnx", 'onnx_model_path'),
        ('--device', 'cuda:0', 'Device used for inference'),
        ('--output', r'G:\mmrotate\demo\sxj731533730.jpg', 'Output image'),
    )
    for flag, default_value, help_text in string_options:
        parser.add_argument(flag, default=default_value, help=help_text)
    parser.add_argument('--score-thr', type=float, default=0.3, help='bbox score threshold')
    args = parser.parse_args()

    # Build the detector, run it on the single image and save the drawing.
    model = init_detector(args.config, args.checkpoint, device=args.device)
    detections = inference_detector(model, args.img)
    model.show_result(args.img, detections, score_thr=args.score_thr, out_file=args.output)

    torch2onnx(args, model)


if __name__ == '__main__':
    main()

但是转出来的模型有问题,因为涉及到mmcv的函数,

需要修改一下源码

举报

相关推荐

0 条评论