0
点赞
收藏
分享

微信扫一扫

SVM简单应用python代码 dogs vs cats

菜头粿子园 2022-02-27 阅读 48
from sklearn.preprocessing import LabelEncoder
from sklearn.svm import LinearSVC
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from IPython.display import Image
from imutils import paths
import numpy as np
import cv2
import os



def extract_histogram(image, bins=(8, 8, 8)):
    """Return a flattened, normalized 3-channel color histogram of *image*.

    Args:
        image: Image array as produced by ``cv2.imread`` in color mode;
            channels 0, 1 and 2 are histogrammed jointly over [0, 256).
        bins: Number of histogram bins per channel. The returned vector has
            ``bins[0] * bins[1] * bins[2]`` entries (512 with the default).

    Returns:
        1-D array holding the normalized histogram values.

    Raises:
        ValueError: If *image* is ``None``, which is what ``cv2.imread``
            returns when a file is missing or cannot be decoded.
    """
    if image is None:
        # Fail early with a clear message instead of an opaque OpenCV error.
        raise ValueError("image is None; cv2.imread could not load the file")
    hist = cv2.calcHist([image], [0, 1, 2], None, bins, [0, 256, 0, 256, 0, 256])
    # Normalize in place: the same array is passed as source and destination.
    cv2.normalize(hist, hist)
    return hist.flatten()

# Collect every training image. Sorting fixes the sample order, so the seeded
# train/test split below sees the same input sequence on every run.
imagePaths = sorted(paths.list_images('./data/train'))
data = []
labels = []

for imagePath in imagePaths:
    image = cv2.imread(imagePath, 1)  # flag 1: load as a 3-channel color image
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError("could not read image: " + imagePath)
    # The class label is the part of the file name before the first dot
    # (e.g. "cat" for a file named like cat.1016.jpg).
    label = os.path.basename(imagePath).split(".")[0]
    hist = extract_histogram(image)
    data.append(hist)
    labels.append(label)

# Map the string labels to integer class ids.
le = LabelEncoder()
labels = le.fit_transform(labels)

# Hold out 25% of the samples for evaluation; the seed keeps the split fixed.
(trainData, testData, trainLabels, testLabels) = train_test_split(
    np.array(data), labels, test_size=0.25, random_state=2
)

model = LinearSVC(random_state=2, C=0.94)
model.fit(trainData, trainLabels)
from sklearn.metrics import f1_score

# 1-3: selected weights of the fitted linear model, rounded to two decimals.
for idx in (280, 129, 440):
    print(np.round(model.coef_[0][idx], 2))

# 4: macro-averaged F1 score on the held-out split.
predictions = model.predict(testData)
macro_f1 = f1_score(testLabels, predictions, average='macro')
print(np.round(macro_f1, 2))
# For a per-class breakdown, enable:
# print(classification_report(testLabels, predictions, target_names=le.classes_))

# 5-8: classify individual images from the test folder with the trained model.
# Each prediction is printed as the encoded class id array returned by predict().
for singleImagePath in (
    './data/test/cat.1016.jpg',
    './data/test/cat.1024.jpg',
    './data/test/dog.1006.jpg',
    './data/test/dog.1033.jpg',
):
    singleImage = cv2.imread(singleImagePath)
    if singleImage is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError("could not read image: " + singleImagePath)
    histt = extract_histogram(singleImage)
    # Reshape the feature vector into a single-row 2-D array for predict().
    histt2 = histt.reshape(1, -1)
    prediction = model.predict(histt2)
    print(prediction)
举报

相关推荐

0 条评论