视觉机器学习20讲 - MATLAB Source Code Examples (9) - The SVM Algorithm
1. The SVM Algorithm
SVM (Support Vector Machine) is a binary classification model, although with suitable extensions it can also be applied to multi-class problems. Support vector machines fall into two broad categories, linear and nonlinear. The core idea is to find a hyperplane in the feature space that separates the data samples, such that the distance from the nearest samples to this hyperplane is as large as possible.
In other words, SVM searches the data for a separating hyperplane that divides the two classes completely, while maximizing the margin between them during model construction.
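In the standard formulation, the separating hyperplane is $w^\top x + b = 0$, and maximizing the margin amounts to solving

$$
\min_{w,b}\ \frac{1}{2}\lVert w\rVert^2 \quad \text{s.t.}\quad y_i\left(w^\top x_i + b\right) \ge 1,\quad i = 1,\dots,n,
$$

since the width of the margin is $2/\lVert w\rVert$, so minimizing $\lVert w\rVert^2$ maximizes it.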
(1) Linearly separable: if a hyperplane can be found that separates the two groups of samples in a data set, the data set is called linearly separable.
(2) Linearly inseparable: if no hyperplane exists that separates the two groups of samples, the data set is called linearly inseparable.
(3) Separating hyperplane: the line/plane that splits the data set is called the separating hyperplane.
(4) Support vectors: the data points closest to the separating hyperplane are called support vectors.
(5) Margin: the distance from the support vectors to the separating hyperplane is called the margin.
(6) KKT conditions: the KKT conditions are a form of the generalized Lagrange multiplier method; they provide a way to solve optimization problems whose constraints include inequalities. A concrete statement follows below.
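In its standard form, for a problem $\min_x f(x)$ subject to $g_i(x) \le 0$ with Lagrangian $L(x,\mu) = f(x) + \sum_i \mu_i g_i(x)$, the KKT conditions are

$$
\nabla_x L(x,\mu) = 0,\qquad g_i(x) \le 0,\qquad \mu_i \ge 0,\qquad \mu_i\,g_i(x) = 0.
$$

The complementary-slackness condition $\mu_i\,g_i(x) = 0$ is what drives most dual multipliers $\alpha_i$ of an SVM to zero, so that only the support vectors keep $\alpha_i > 0$.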
2. MATLAB Simulation
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Purpose: demonstrate the SVM algorithm in computer vision
%          feature classification based on SVM
% Environment: Win7, Matlab2018a
% Modi: C.C
% Date: 2022-4-5
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% penalty parameter and kernel function
C = 200;
ker = struct('type','linear');
% generate a two-class Gaussian test sample (n points per class)
n = 50;
randn('state',6);                      % fix the random seed (legacy syntax)
x1 = randn(2,n);                       % class +1, centered at the origin
y1 = ones(1,n);
x2 = 5+randn(2,n);                     % class -1, centered at (5,5)
y2 = -ones(1,n);
figure(1);
plot(x1(1,:),x1(2,:),'bx',x2(1,:),x2(2,:),'k.');
axis([-3 8 -3 8]);
title('C-SVC')
hold on;
X = [x1,x2];
Y = [y1,y2];
% train SVM
tic
svm = svmTrain('svc_c',X,Y,ker,C);
t_train = toc
% find and mark the support vectors
a = svm.a;
epsilon = 1e-8;                        % multipliers above this threshold count as SVs
i_sv = find(abs(a)>epsilon);
plot(X(1,i_sv),X(2,i_sv),'ro');
% evaluate the decision function on a grid and draw the decision boundary
[x1,x2] = meshgrid(-2:0.1:7,-2:0.1:7);
[rows,cols] = size(x1);
nt = rows*cols;
Xt = [reshape(x1,1,nt);reshape(x2,1,nt)];
tic
Yd = svmSim(svm,Xt);
t_sim = toc
Yd = reshape(Yd,rows,cols);
contour(x1,x2,Yd,[0 0],'m');           % zero level set = decision boundary
hold off;
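The svmTrain function below maps each SVM variant onto MATLAB's quadprog. For the 'svc_c' branch used in this demo, the quadratic program being solved is the standard C-SVC dual:

$$
\min_{\alpha}\ \frac{1}{2}\,\alpha^\top H \alpha - \mathbf{1}^\top \alpha,\qquad H_{ij} = y_i\,y_j\,K(x_i, x_j),\qquad \text{s.t.}\quad y^\top\alpha = 0,\ \ 0 \le \alpha_i \le C,
$$

which corresponds term by term to H = (Y'*Y).*kernel(ker,X,X), f = -ones(n,1), Aeq = Y, beq = 0, lb = zeros(n,1), and ub = C*ones(n,1) in the code.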
function svm = svmTrain(svmType,X,Y,ker,p1,p2)
% SVM Classification:
%   svm = svmTrain('svc_c',x,y,ker,C);
%   svm = svmTrain('svc_nu',x,y,ker,nu);
%
% One-Class SVM:
%   svm = svmTrain('svm_one_class',x,[],ker,nu);
%
% SVM Regression:
%   svm = svmTrain('svr_epsilon',x,y,ker,C,e);
%   svm = svmTrain('svr_nu',x,y,ker,C,nu);
options = optimset;
options.LargeScale = 'off';
options.Display = 'off';
switch svmType
    case 'svc_c'                       % C-SVC: box constraint 0 <= a <= C
        C = p1;
        n = length(Y);
        H = (Y'*Y).*kernel(ker,X,X);   % dual Hessian H_ij = y_i*y_j*K(x_i,x_j)
        f = -ones(n,1);
        A = [];
        b = [];
        Aeq = Y;                       % equality constraint: sum(y_i*a_i) = 0
        beq = 0;
        lb = zeros(n,1);
        ub = C*ones(n,1);
        a0 = zeros(n,1);
        [a,fval,exitflag,output,lambda] = quadprog(H,f,A,b,Aeq,beq,lb,ub,a0,options);
    case 'svc_nu'                      % nu-SVC: nu bounds the margin-error fraction
        nu = p1;
        n = length(Y);
        H = (Y'*Y).*kernel(ker,X,X);
        f = zeros(n,1);
        A = -ones(1,n);                % inequality: sum(a_i) >= nu
        b = -nu;
        Aeq = Y;
        beq = 0;
        lb = zeros(n,1);
        ub = ones(n,1)/n;
        a0 = zeros(n,1);
        [a,fval,exitflag,output,lambda] = quadprog(H,f,A,b,Aeq,beq,lb,ub,a0,options);
    case 'svm_one_class'               % one-class SVM: smallest enclosing sphere
        nu = p1;
        n = size(X,2);
        H = kernel(ker,X,X);
        f = zeros(n,1);
        for i = 1:n
            f(i,:) = -kernel(ker,X(:,i),X(:,i));
        end
        A = [];
        b = [];
        Aeq = ones(1,n);
        beq = 1;
        lb = zeros(n,1);
        ub = ones(n,1)/(nu*n);
        a0 = zeros(n,1);
        [a,fval,exitflag,output,lambda] = quadprog(H,f,A,b,Aeq,beq,lb,ub,a0,options);
    case 'svr_epsilon'                 % epsilon-SVR: dual over a = a_plus - a_minus
        C = p1;
        e = p2;
        n = length(Y);
        Q = kernel(ker,X,X);
        H = [Q,-Q;-Q,Q];
        f = [e*ones(n,1)-Y';e*ones(n,1)+Y'];
        %f = [e*ones(n,1)+Y';e*ones(n,1)-Y'];
        A = [];
        b = [];
        Aeq = [ones(1,n),-ones(1,n)];
        beq = 0;
        lb = zeros(2*n,1);
        ub = C*ones(2*n,1);
        a0 = zeros(2*n,1);
        [a,fval,exitflag,output,lambda] = quadprog(H,f,A,b,Aeq,beq,lb,ub,a0,options);
        a = a(1:n)-a(n+1:end);         % net coefficients a_plus - a_minus
    case 'svr_nu'                      % nu-SVR: epsilon is adjusted automatically
        C = p1;
        nu = p2;
        n = length(Y);
        Q = kernel(ker,X,X);
        H = [Q,-Q;-Q,Q];
        f = [-Y';+Y'];
        %f = [+Y';-Y'];
        A = [];
        b = [];
        Aeq = [ones(1,n),-ones(1,n);ones(1,2*n)];
        beq = [0;C*n*nu];
        lb = zeros(2*n,1);
        ub = C*ones(2*n,1);
        a0 = zeros(2*n,1);
        [a,fval,exitflag,output,lambda] = quadprog(H,f,A,b,Aeq,beq,lb,ub,a0,options);
        a = a(1:n)-a(n+1:end);
    otherwise
end
exitflag                               % display quadprog exit status (1 = converged)
% pack the trained model
svm.type = svmType;
svm.ker = ker;
svm.x = X;
svm.y = Y;
svm.a = a';
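As a quick illustration of the other call forms, here is a minimal sketch (the kernel width, the nu value, and the reuse of X and Y from the demo script are illustrative choices, not part of the original code; it would be run from the demo script, not inside this file):

% hypothetical alternative run: nu-SVC with a Gaussian kernel
ker_g = struct('type','gauss','width',2);   % assumed kernel width
nu = 0.2;                                   % assumed bound on the margin-error fraction
svm_nu = svmTrain('svc_nu',X,Y,ker_g,nu);
Yd_nu = svmSim(svm_nu,X);                   % resubstitution predictions in {-1,+1}
train_acc = mean(Yd_nu == Y)                % fraction of training points classified correctly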
function Yd = svmSim(svm,Xt)
% evaluate the trained SVM on Xt block by block, so that the kernel
% matrix never exceeds roughly 'cache' entries at a time
cache = 10e+6;
nx = size(svm.x,2);
nt = size(Xt,2);
block = ceil(nx*nt/cache);
num = ceil(nt/block);
Yd = zeros(1,nt);                      % preallocate the output
for i = 1:block
    if (i==block)
        index = (i-1)*num+1:nt;        % last block takes the remainder
    else
        index = (i-1)*num+(1:num);
    end
    Yd(index) = svmSim_block(svm,Xt(:,index));
end
% ------------------------------------------------------------%
function Yd = svmSim_block(svm,Xt)
type = svm.type;
ker = svm.ker;
X = svm.x;
Y = svm.y;
a = svm.a;
% identify the support vectors, then compute the decision output
epsilon = 1e-8;
i_sv = find(abs(a)>epsilon);
switch type
    case 'svc_c'
        tmp = (a.*Y)*kernel(ker,X,X(:,i_sv));
        b = Y(i_sv)-tmp;               % bias estimated at each support vector
        b = mean(b);                   % ... then averaged for stability
        tmp = (a.*Y)*kernel(ker,X,Xt);
        Yd = sign(tmp+b);
    case 'svc_nu'
        tmp = (a.*Y)*kernel(ker,X,X(:,i_sv));
        b = Y(i_sv)-tmp;
        b = mean(b);
        tmp = (a.*Y)*kernel(ker,X,Xt);
        Yd = sign(tmp+b);
    case 'svm_one_class'
        % squared radius of the enclosing sphere, estimated at the SVs
        n_sv = length(i_sv);
        tmp1 = zeros(n_sv,1);
        for i = 1:n_sv
            index = i_sv(i);
            tmp1(i) = kernel(ker,X(:,index),X(:,index));
        end
        tmp2 = 2*a*kernel(ker,X,X(:,i_sv));
        tmp3 = sum(sum(a'*a.*kernel(ker,X,X)));
        R_square = tmp1-tmp2'+tmp3;
        R_square = mean(R_square);
        % positive output: the test point lies outside the sphere (outlier)
        nt = size(Xt,2);
        tmp4 = zeros(nt,1);
        for i = 1:nt
            tmp4(i) = kernel(ker,Xt(:,i),Xt(:,i));
        end
        tmp5 = 2*a*kernel(ker,X,Xt);
        Yd = sign(tmp4-tmp5'+tmp3-R_square);
    case 'svr_epsilon'
        tmp = a*kernel(ker,X,X(:,i_sv));
        b = Y(i_sv)-tmp;
        %b = Y(i_sv)+tmp;
        b = mean(b);
        tmp = a*kernel(ker,X,Xt);
        %tmp = -a*kernel(ker,X,Xt);
        Yd = (tmp+b);
    case 'svr_nu'
        tmp = a*kernel(ker,X,X(:,i_sv));
        b = Y(i_sv)-tmp;
        %b = Y(i_sv)+tmp;
        b = mean(b);
        tmp = a*kernel(ker,X,Xt);
        %tmp = -a*kernel(ker,X,Xt);
        Yd = (tmp+b);
    otherwise
end
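In the classification branches above, the bias is estimated at every support vector and averaged, which is more robust to the tolerance epsilon than using a single point:

$$
b = \frac{1}{|S|}\sum_{i \in S}\Big(y_i - \sum_{j=1}^{n} \alpha_j\,y_j\,K(x_j, x_i)\Big),\qquad f(x) = \operatorname{sign}\Big(\sum_{j=1}^{n} \alpha_j\,y_j\,K(x_j, x) + b\Big);
$$

the regression branches use the same construction, with the factor $y_j$ absorbed into the signed coefficients $\alpha_j$.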
function K = kernel(ker,x,y)
% Calculate the kernel matrix: K(i,j) = k(x(:,i),y(:,j)).
switch ker.type
    case 'linear'
        K = x'*y;
    case 'poly'                        % polynomial kernel (x'y + c)^d
        d = ker.degree;
        c = ker.offset;
        K = (x'*y+c).^d;
    case 'gauss'                       % Gaussian (RBF) kernel
        s = ker.width;
        rows = size(x,2);
        cols = size(y,2);
        tmp = zeros(rows,cols);
        for i = 1:rows
            for j = 1:cols
                tmp(i,j) = norm(x(:,i)-y(:,j));
            end
        end
        K = exp(-0.5*(tmp/s).^2);
    case 'tanh'                        % sigmoid kernel
        g = ker.gamma;
        c = ker.offset;
        K = tanh(g*x'*y+c);
    otherwise
        K = 0;
end
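The parameter fields each kernel type expects can be read off the cases above; for convenience (the numeric values below are arbitrary illustrations):

% example kernel configuration structs for the kernel() function above
ker_lin  = struct('type','linear');
ker_poly = struct('type','poly','degree',3,'offset',1);
ker_rbf  = struct('type','gauss','width',2);
ker_sig  = struct('type','tanh','gamma',0.1,'offset',-1);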
3. Simulation Results
Running the script plots the two Gaussian classes, marks the support vectors with red circles, and draws the separating boundary as the magenta zero-contour of the decision function; t_train and t_sim report the training and evaluation times.
4. Summary
SVM extends to multi-class problems: for each class, a binary classifier of the current class against all the others (one-vs-rest) converts a multi-class problem into n binary problems, where n is the number of classes; a schematic wrapper is sketched below.
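A one-vs-rest wrapper over the svmTrain/svmSim functions above might look like this (a sketch only: the label vector Yc, the test matrix Xt, and the voting rule are illustrative assumptions; since svmSim returns signed labels rather than raw decision values, ties are broken in favor of the lowest class index):

% hypothetical one-vs-rest wrapper; assumes class labels Yc in 1..nClass,
% training data X, test data Xt, and ker/C as in the demo script
nClass = max(Yc);
models = cell(nClass,1);
for k = 1:nClass
    Yk = 2*(Yc==k)-1;                       % current class vs. all others
    models{k} = svmTrain('svc_c',X,Yk,ker,C);
end
votes = zeros(nClass,size(Xt,2));
for k = 1:nClass
    votes(k,:) = svmSim(models{k},Xt);      % +1 where classifier k claims the sample
end
[~,pred] = max(votes,[],1);                 % predicted class index per test column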
Characteristics of the SVM algorithm:
(1) The complexity of a trained model is determined by the number of support vectors rather than by the dimensionality of the data, so SVM is not very prone to overfitting.
(2) The trained model depends entirely on the support vectors: even if every non-support-vector point were removed from the training set and training repeated, exactly the same model would be obtained.
(3) If training yields relatively few support vectors, the resulting model tends to generalize well.
SVM is a classic algorithm in machine vision and signal processing, and being able to derive and implement it from scratch in an interview is a significant advantage. During graduate school I used SVM as the key algorithm for license-plate character recognition and obtained good recognition results; a detailed write-up may follow later. Interested readers are encouraged to study the ninth lecture of 《视觉机器学习20讲》 in full, along with well-regarded online courses on SVM, for a deeper understanding.