0
点赞
收藏
分享

微信扫一扫

YOLOV8解读--分类模型训练与预测

YOLOV8命名不同于V5系列,V8不强调yolo这一模型,更加强调其框架属性,所以V8在github上的项目名为 ultralytics 。

所以在下载代码时不要怀疑,就是这个地址:https://github.com/ultralytics/ultralytics

下边写基于V8模型做分类的方法

数据集部分

分类模型的数据集与V5的检测或分割最大的不用试,既不需要配置文件也不需要标注文件。

数据集先直接分为train和val(必须是这两个名称)

YOLOV8解读--分类模型训练与预测_训练代码

然后再在teain和val下再分不同的类别,一个类别一个文件夹,文件夹的名称既是类别的名称,文件夹内全部是图片,没有标注文件或其他。

YOLOV8解读--分类模型训练与预测_分类模型_02

训练部分

训练和预测开始前需要配置好V8所使用的环境,环境依照项目的requirements配置即可。

训练代码如下:

#from ultralytics import RTDETR
from ultralytics import YOLO

if __name__ == "__main__":
    model = YOLO('D:/code/model/ultralytics-main/models/yolov8s-cls.pt')

    # Display model information (optional)
    # model.info()

    # Train the model on the COCO8 example dataset for 100 epochs
    results = model.train(task='classify',mode='train', \
        data='E:/data/classfication_datasets_2', batch=8, imgsz=640, \
            val=False, cache=False, optimizer='auto',cos_lr=True, close_mosaic=25, amp=False, half=False, \
            dnn=False, int8=False, dynamic=False, simplify=False, degrees=30, mosaic=1, augment=True, mixup=0.75, \
            copy_paste=0.5, scale=0.5,epochs=100,device=0,workers=2)

注意 :这里 ‘model = YOLO('D:/code/model/ultralytics-main/models/yolov8s-cls.pt')’ 括号内是提前下载好的预训练模型的地址。预训练模型内已经包含网络结构类型等,所以不需要另外的网络配置文件。

预训练权重下载地址在项目的readme中有说明,如下:

YOLOV8解读--分类模型训练与预测_数据集_03

data中  ‘data='E:/data/classfication_datasets_2',’  这里的地址是数据集目录的地址(train文件夹上一级)。其他参数和V5类似的含义,直接按需求修改(还有一部分参数这里未给出,是默认值,也可以直接设置)。

预测部分

代码如下:

from ultralytics import YOLO
import cv2
import numpy as np
import os
from tqdm import tqdm


def cv_show(name:str,img):
    cv2.namedWindow(name,0)
    cv2.imshow(name, img)
    cv2.waitKey(100)
    # cv2.destroyAllWindows()
    return


class ClsPredict():
    # 预训练权重,训练权重
    def __init__(self,official_model,custom_model) -> None:
        self.model = YOLO(official_model)  # load an official model
        self.model = YOLO(custom_model)  # load a custom model
        
    # 单张预测
    def predict(self,img):

        # 0Predict with the model  
        results = self.model(img) 
        
        # 各类别名称
        names = results[0].names
        # 各类别置信度
        confs = results[0].probs.data.cpu().numpy()
        # 置信度最高的索引
        max_index = np.argmax(confs)

        state = names[max_index]
        score = confs[max_index]

        return state+str(score)
 if __name__ == "__main__":
    # 预训练权重地址(与训练时使用的是同一个)
    official_model = 'D:/code/model/ultralytics-main/models/yolov8s-cls.pt'
    # 训练的得到的权重的地址
    custom_model = 'D:/code/model/ultralytics-main/runs/classify/train6/weights/last.pt'
    
    clspredic = ClsPredict(official_model,custom_model)
    
    img = cv2.imread(img_path)
    results = clspredic.predict(img)  # predict on an image

最终返回的结果包含每个类的名称和每个类的置信度,一一对应。


举报

相关推荐

0 条评论