0
点赞
收藏
分享

微信扫一扫

机器学习笔记 - 使用Detecto构建自定义对象检测

一、Detecto简介

        Detecto 是一个 Python 包,只需 5 行代码,您就可以构建功能齐全的计算机视觉和对象检测模型。静态图像和视频的推理、自定义数据集的迁移学习以及将模型序列化为文件只是 Detecto 的一些功能。Detecto 也建立在 PyTorch 之上,允许在两个库之间轻松传输模型。

        Detecto 的强大之处在于其简单易用。 从 PyTorch 的模型动物园创建和运行预训练的 Faster R-CNN ResNet-50 FPN 需要 4 行代码:

from detecto.core import Model
from detecto.visualize import detect_video

model = Model()  # Initialize a pre-trained model
detect_video(model, 'input_video.mp4', 'output.avi')  # Run inference on a video

        使用pip命令可以进行安装

pip install detecto

二、构建自定义对象检测模型

1、数据集

        我们这里使用一个狗狗的数据集,数据集下载地址。该数据集已由 Detecto 的开发人员标记,我们只需将其导入到我们的环境中。该数据集包含 300 张带标签的金毛猎犬和吉娃娃犬图像。

        数据集内有三个文件夹images、train_labels、val_labels分别包含了图像文件,训练集的标签,验证集的标签。

2、数据预处理

(1)转换标签及可视化图片

from detecto import utils
import matplotlib.pyplot as plt
import matplotlib.image as img
from torchvision import transforms
from detecto import core
from detecto import visualize

utils.xml_to_csv('train_labels', 'train.csv')
utils.xml_to_csv('val_labels', 'val.csv')

image = img.imread('images/n02085620_8611.jpg')
plt.imshow(image)
plt.show()

 (2)进行图片变换及可视化

transform_img = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(800),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ToTensor(),
    utils.normalize_transform(),
])

dataset = core.Dataset('train.csv', 'images/', transform=transform_img)

image, information = dataset[50]
visualize.show_labeled_image(image, information['boxes'], information['labels'])

3、训练模型并进行预测

        这里使用 Faster-RCNN-resnet50 模型。core.DataLoader默认下载resnet50模型,另外目前还可以选mobilenet_v3和mobilenet_v3_320两种。

dataloader = core.DataLoader(dataset)
validation_data = core.Dataset('val.csv', 'images/')
categories = ['Chihuahua', 'golden_retriever']
classifier = core.Model(categories)
history = classifier.fit(dataloader, validation_data, epochs = 20, verbose = True)
plt.plot(history)


images = []
for i in range(0,36,3):
  image,_ = validation_data[i]
  images.append(image)

visualize.plot_prediction_grid(classifier, images, dim=(4, 3), figsize=(16, 12))

        训练输出 

          预测结果如下 

 四、总结

        Detecto库用起来还是很方便的,不过版本兼容性等还是有一些问题,运行的时候会找很多版本,dll等问题,如果有兴趣的可以看看它的github。

https://github.com/alankbi/detectoicon-default.png?t=M3C8https://github.com/alankbi/detecto

举报

相关推荐

0 条评论