
Multiclass SVM with hinge loss in MATLAB

  • Introduction
  • Hinge loss
  • Code

1 Introduction

This article introduces the hinge loss $E(w)$ and its gradient $\nabla E(w)$, and uses batch gradient descent on the hinge loss to train a multiclass SVM. In experiments on a handwritten-digit database, this approach reaches a recognition accuracy of 87.040%.


2 Hinge loss

  1. Starting from the objective of the binary SVM, the objective of the multiclass SVM can be defined as
    $$E(w_1,\ldots,w_k)=\sum_{j=1}^k\frac{1}{2}\|w_j\|^2+C\sum_{i=1}^{n}L((w_1,\ldots,w_k),(x_i,y_i)),$$

where $T=\{(x_1,y_1),\ldots,(x_n,y_n)\}$ is the training set and
$$L((w_1,\ldots,w_k),(x,y))=\max\Big(0,\ \max_{y'\neq y}w_{y'}^{T}x+1-w_{y}^{T}x\Big).$$
In words, the score of the correct class must exceed the score of every other class by at least 1; otherwise the sample is penalized linearly. For the extension of the binary SVM to the multiclass case and the derivation of these formulas, see other references.
2. Next, compute the gradient of $E(w)$ (strictly a subgradient, since $L$ has kinks). The regularizer contributes $w_j$, so $\nabla_{w_j}E=w_j+C\sum_{i=1}^{n}\nabla_{w_j}L((w_1,\ldots,w_k),(x_i,y_i))$. For the loss of a single sample $(x,y)$, let $\hat{y}=\arg\max_{y'\neq y}w_{y'}^{T}x$ be the highest-scoring wrong class, and let $w_{j,l}$ denote the $l$-th component of $w_j$.
(a) If $w_y^{T}x\geq w_{\hat{y}}^{T}x+1$, then for every $j$

$$\frac{\partial L((w_1,w_2,\ldots,w_k),(x,y))}{\partial w_{j,l}}=0$$

(b) If $w_y^{T}x< w_{\hat{y}}^{T}x+1$ and $j=y$, then

$$\frac{\partial L((w_1,w_2,\ldots,w_k),(x,y))}{\partial w_{j,l}}=-x_{l}$$

(c) If $w_y^{T}x< w_{\hat{y}}^{T}x+1$ and $j=\hat{y}$, then

$$\frac{\partial L((w_1,w_2,\ldots,w_k),(x,y))}{\partial w_{j,l}}=x_{l}$$

(d) If $w_y^{T}x< w_{\hat{y}}^{T}x+1$, $j\neq y$ and $j\neq \hat{y}$, then

$$\frac{\partial L((w_1,w_2,\ldots,w_k),(x,y))}{\partial w_{j,l}}=0$$

  3. Update $W=\{w_1,\ldots,w_k\}$ with gradient descent:
    $$W_t=W_{t-1}-r\nabla E(W_{t-1}),$$
    where $r$ is the learning rate.
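The objective, the subgradient cases (a) to (d) and the update rule above can be tried out on a toy problem. The following pure-Python sketch is an illustration only, not a translation of the author's MATLAB code: `hinge_loss` computes $L$ for one sample, `hinge_subgrad` implements cases (a) to (d), and `train` runs full-batch gradient descent on $E$ starting from $W_0=0$. The function names, the toy data and the values `C=1.0`, `r=0.01` and 200 iterations are assumptions chosen for the example; classes are indexed from 0.

```python
def dot(u, v):
    """Inner product of two equal-length lists."""
    return sum(a * b for a, b in zip(u, v))


def hinge_loss(W, x, y):
    """Multiclass hinge loss L((w_1..w_k),(x,y)) for one sample.

    Returns the loss and y_hat, the highest-scoring wrong class."""
    scores = [dot(w, x) for w in W]
    y_hat = max((j for j in range(len(W)) if j != y), key=lambda j: scores[j])
    return max(0.0, scores[y_hat] + 1.0 - scores[y]), y_hat


def hinge_subgrad(W, x, y):
    """Subgradient of L with respect to every w_{j,l}, following cases (a)-(d)."""
    loss, y_hat = hinge_loss(W, x, y)
    G = [[0.0] * len(x) for _ in W]      # cases (a) and (d): zero
    if loss > 0:                         # i.e. w_y^T x < w_yhat^T x + 1
        G[y] = [-xl for xl in x]         # case (b): j = y
        G[y_hat] = list(x)               # case (c): j = y_hat
    return G


def objective(W, data, C):
    """E(w_1..w_k) = sum_j 0.5*||w_j||^2 + C * sum_i L(...)."""
    reg = sum(0.5 * dot(w, w) for w in W)
    return reg + C * sum(hinge_loss(W, x, y)[0] for x, y in data)


def train(data, k, d, C=1.0, r=0.01, iterations=200):
    """Batch gradient descent W_t = W_{t-1} - r * grad E(W_{t-1}), from W_0 = 0."""
    W = [[0.0] * d for _ in range(k)]
    for _ in range(iterations):
        grad = [list(w) for w in W]      # gradient of the regularizer is w_j
        for x, y in data:
            G = hinge_subgrad(W, x, y)
            for j in range(k):
                for l in range(d):
                    grad[j][l] += C * G[j][l]
        W = [[W[j][l] - r * grad[j][l] for l in range(d)] for j in range(k)]
    return W


# Toy data: three separable classes in 2-D plus a constant bias feature of 1.
data = [([2.0, 0.0, 1.0], 0), ([0.0, 2.0, 1.0], 1), ([-2.0, -2.0, 1.0], 2)]
W = train(data, k=3, d=3)
preds = [max(range(3), key=lambda j: dot(W[j], x)) for x, _ in data]
print(preds, objective(W, data, C=1.0))
```

With $W_0=0$ all scores tie at the first step, and `max` then picks the first wrong class as $\hat{y}$; any wrong class gives a valid subgradient there, so this choice does not affect correctness.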

3 Code

Muliticlass_svm.m

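% Author: 何凌霄
% Institute of Automation, Chinese Academy of Sciences
% March 15, 2017
clear all
clc
%% STEP 0: Initialise constants and parameters
inputSize = 28 * 28; % Size of input vector (MNIST images are 28x28)
numClasses = 10;     % Number of classes (MNIST images fall into 10 classes)
lambda = 1e-2;       % Weight decay parameter
learning_rate = 0.1; % Step size r of gradient descent
iteration = 400;     % Number of batch gradient descent steps
%%======================================================================
%% STEP 1: Load data
load('digits.mat')   % provides train0 ... train9 (500 images per digit, one per row)
images = [train1; train2; train3; train4; train5; train6; train7; train8; train9; train0];
images = images';    % one image per column
% Label k marks digit k; label 10 marks digit 0 (train0 is stacked last)
labels = [ones(500,1);2*ones(500,1);3*ones(500,1);4*ones(500,1);5*ones(500,1);6*ones(500,1);7*ones(500,1);8*ones(500,1);9*ones(500,1);10*ones(500,1)];
index = randperm(500*10);   % shuffle the training samples
images = images(:,index);
labels = labels(index);
inputData = images;
%% STEP 2: Train
% NOTE: this part of the original listing is corrupted. The call below is
% reconstructed from the signature of multisvmtrain.m and from the later
% use of svmOptTheta and svmModel.optTheta.
[lcost, grad, svmOptTheta] = multisvmtrain(numClasses, inputSize, lambda, ...
    inputData, labels, iteration, learning_rate);
svmModel.optTheta = svmOptTheta;
%% STEP 3: Test
% NOTE: the code that loaded the evaluation images was also lost. The labels
% below are in class order (500 per class), as compute_confusion_matrix
% requires, so the images must be in the same order. As a stand-in
% (assumption), the unshuffled training images are used here, which
% measures training accuracy.
images = [train1; train2; train3; train4; train5; train6; train7; train8; train9; train0]';
labels = [ones(500,1);2*ones(500,1);3*ones(500,1);4*ones(500,1);5*ones(500,1);6*ones(500,1);7*ones(500,1);8*ones(500,1);9*ones(500,1);10*ones(500,1)];
inputData = images;
[pred] = Multi_SVMPredict(svmModel, inputData);
acc = mean(labels(:) == pred(:));
num_in_class = 500*ones(10,1)';
name_class = cell(1, 10);
for i=1:10
    name_class{i} = num2str(i);   % class 10 corresponds to digit 0
end
[confusion_matrix] = compute_confusion_matrix(pred, num_in_class, name_class);
figure; visualize(svmOptTheta');   % one 28x28 tile per class weight vector
fprintf('Accuracy: %0.3f%%\n', acc * 100);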

multisvmtrain.m

% Author: 何凌霄
% Institute of Automation, Chinese Academy of Sciences
% March 15, 2017
function [lcost, grad, theta] = multisvmtrain(numClasses, inputSize, lambda, data, labels, iteration, learning_rate)
% Batch (sub)gradient descent on the multiclass hinge loss.
% data: inputSize x numCases, labels: numCases x 1 with values 1..numClasses
theta = 0.005 * randn(numClasses, inputSize);   % small random weights, one row per class
numCases = size(data, 2);                       % number of training samples
% One-hot label matrix (numClasses x numCases): groundTruth(labels(i), i) = 1
groundTruth = full(sparse(labels, 1:numCases, 1, numClasses, numCases));
lcost = zeros(1, iteration);   % objective value at each iteration
grad = zeros(1, iteration);    % sum of all gradient entries (diagnostic only)
for i = 1:iteration
    [Q, ~, cost] = multi_hingeloss_cost(theta, data, groundTruth, lambda);
    thetagrad = multi_hingeloss_grad(data, theta, Q, groundTruth, lambda, labels);
    theta = theta - learning_rate * thetagrad;
    lcost(i) = cost;
    grad(i) = sum(thetagrad(:));
    fprintf('%d, %f\n', i, cost);
end
end

multi_hingeloss_cost.m

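% Author: 何凌霄
% Institute of Automation, Chinese Academy of Sciences
% March 15, 2017
function [Q, X, cost] = multi_hingeloss_cost(theta, data, groundTruth, lambda)
% theta:       numClasses x inputSize weights, one row per class
% data:        inputSize x n samples, one per column
% groundTruth: numClasses x n one-hot label matrix
X = theta * data;                           % class scores w_j^T x_i
Q = X;
Q(groundTruth == 1) = -inf;                 % hide the true class so max(Q) picks the best wrong class
trueScore = sum(X .* groundTruth, 1);       % w_y^T x for each sample (1 x n)
t = max(0, 1 - trueScore + max(Q, [], 1));  % per-sample multiclass hinge loss
% Averaged loss plus (lambda/2)*||theta||^2; the factor 1/2 makes this cost
% consistent with the lambda*theta term in multi_hingeloss_grad.m
cost = mean(t) + (lambda / 2) * sum(theta(:).^2);
end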

multi_hingeloss_grad.m

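% Author: 何凌霄
% Institute of Automation, Chinese Academy of Sciences
% March 15, 2017
function [thetagrad] = multi_hingeloss_grad(data, theta, Q, groundTruth, lambda, labels)
% Subgradient of the averaged hinge loss (cases (a)-(d) of Section 2) plus
% lambda*theta from the regularizer; Q comes from multi_hingeloss_cost.m
n = size(data, 2);
X = theta * data;                           % class scores
[~, q] = max(Q, [], 1);                     % hardest wrong class y_hat for each sample
trueScore = sum(X .* groundTruth, 1);       % w_y^T x
idx = sub2ind(size(X), q, 1:n);             % linear indices of (y_hat, i)
W = (trueScore - X(idx)) < 1;               % samples whose margin is violated
Y = zeros(size(X));                         % cases (a) and (d): entries stay zero
Y(sub2ind(size(X), labels(:)', 1:n)) = -W;  % case (b): -x for j = y
Y(idx) = W;                                 % case (c): +x for j = y_hat
thetagrad = (1 / n) * Y * data' + lambda * theta;
end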
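The gradient above contains $\lambda\theta$, which is the derivative of $\frac{\lambda}{2}\sum\theta^2$, so the objective being minimized is $\frac{1}{n}\sum_i L+\frac{\lambda}{2}\sum_j\|w_j\|^2$; this equals $\lambda E$ with $C=1/(n\lambda)$. A finite-difference check is a cheap way to confirm that a cost and its gradient agree. The pure-Python sketch below re-implements both for small random data and compares them; it is not the author's code, and the names `cost`, `grad`, `numeric_grad`, the random data and `lam = 0.01` are assumptions. Because the hinge loss has kinks, the check is only meaningful at points away from them, which random data satisfies with high probability.

```python
import random


def scores(theta, x):
    """Class scores theta_j^T x for one sample."""
    return [sum(t * xi for t, xi in zip(row, x)) for row in theta]


def cost(theta, X, labels, lam):
    """(1/n) * sum of multiclass hinge losses + (lam/2) * ||theta||^2."""
    total = 0.0
    for x, y in zip(X, labels):
        s = scores(theta, x)
        worst = max(s[j] for j in range(len(s)) if j != y)  # best wrong-class score
        total += max(0.0, 1.0 - s[y] + worst)
    reg = 0.5 * lam * sum(t * t for row in theta for t in row)
    return total / len(X) + reg


def grad(theta, X, labels, lam):
    """Analytic subgradient with the same structure as multi_hingeloss_grad.m."""
    k, d, n = len(theta), len(theta[0]), len(X)
    G = [[lam * t for t in row] for row in theta]           # regularizer term
    for x, y in zip(X, labels):
        s = scores(theta, x)
        q = max((j for j in range(k) if j != y), key=lambda j: s[j])
        if s[y] - s[q] < 1:                                 # margin violated
            for l in range(d):
                G[y][l] -= x[l] / n                         # case (b)
                G[q][l] += x[l] / n                         # case (c)
    return G


def numeric_grad(theta, X, labels, lam, h=1e-6):
    """Central finite differences of cost() for every entry of theta."""
    G = []
    for j in range(len(theta)):
        row = []
        for l in range(len(theta[0])):
            plus = [list(t) for t in theta]
            minus = [list(t) for t in theta]
            plus[j][l] += h
            minus[j][l] -= h
            row.append((cost(plus, X, labels, lam) - cost(minus, X, labels, lam)) / (2 * h))
        G.append(row)
    return G


random.seed(0)
k, d, n, lam = 4, 5, 20, 0.01
theta = [[random.gauss(0, 1) for _ in range(d)] for _ in range(k)]
X = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
labels = [random.randrange(k) for _ in range(n)]

analytic = grad(theta, X, labels, lam)
numeric = numeric_grad(theta, X, labels, lam)
max_err = max(abs(a - b) for ra, rb in zip(analytic, numeric) for a, b in zip(ra, rb))
print(max_err)
```

Replacing `0.5 * lam` with `lam` in `cost` would make the check fail, which is the mismatch between $\lambda\sum\theta^2$ and a $\lambda\theta$ gradient.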

Multi_SVMPredict.m

% Author: 何凌霄
% Institute of Automation, Chinese Academy of Sciences
% March 15, 2017
function [pred] = Multi_SVMPredict(svmModel, data)
theta = svmModel.optTheta;   % numClasses x inputSize weight matrix
% Predicted label = index of the highest class score in each column
[~, pred] = max(theta * data, [], 1);
end
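Prediction is an argmax over the class scores $\theta x$, and in the main script the labels run from 1 to 10, with label 10 standing for digit 0 because `train0` is stacked last. A small pure-Python sketch of prediction and accuracy under that convention follows; `predict`, `accuracy`, `label_to_digit` and the toy values are illustrative and not part of the original code.

```python
def predict(theta, X):
    """1-based labels argmax_j theta_j^T x, like [~, pred] = max(theta*data) in MATLAB."""
    preds = []
    for x in X:
        s = [sum(t * xi for t, xi in zip(row, x)) for row in theta]
        preds.append(max(range(len(s)), key=lambda j: s[j]) + 1)  # first maximum, as in MATLAB
    return preds


def accuracy(labels, preds):
    """Fraction of positions where the true and predicted labels agree."""
    return sum(int(a == b) for a, b in zip(labels, preds)) / len(labels)


def label_to_digit(label):
    """Label 10 encodes digit 0; labels 1..9 are the digits themselves."""
    return label % 10


theta = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
X = [[2.0, 0.0], [0.0, 3.0], [-1.0, -1.0], [1.0, 0.5]]
preds = predict(theta, X)
print(preds, accuracy([1, 2, 3, 2], preds))  # [1, 2, 3, 1] 0.75
```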

compute_confusion_matrix.m

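function [confusion_matrix] = compute_confusion_matrix(predict_label, num_in_class, name_class)
% predict_label: 1 x N row vector of predicted labels, ordered so that the
%                samples of class 1 come first, then class 2, and so on
% num_in_class:  number of samples in each class
% name_class:    class names, used as axis tick labels
num_class = length(num_in_class);
num_in_class = [0 num_in_class];
confusion_matrix = zeros(num_class, num_class);

for ci = 1:num_class
    c_start = sum(num_in_class(1:ci)) + 1;   % first sample of class ci
    c_end = sum(num_in_class(1:ci+1));       % last sample of class ci
    for cj = 1:num_class
        % fraction of class-ci samples that were predicted as class cj
        summer = nnz(predict_label(c_start:c_end) == cj);
        confusion_matrix(ci, cj) = summer / num_in_class(ci+1);
    end
end

draw_cm(confusion_matrix, name_class, num_class);

end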
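compute_confusion_matrix.m relies on the predictions being ordered in contiguous class blocks of sizes `num_in_class`, which holds only when the evaluation data are not shuffled. A sketch that instead takes the true labels explicitly, so it works for any sample order, could look like this; the function name and inputs are hypothetical, and labels are 1-based as in the MATLAB code.

```python
def confusion_matrix(true_labels, pred_labels, num_classes):
    """Row-normalised confusion matrix: entry [i][j] is the fraction of
    class i+1 samples predicted as class j+1 (labels are 1-based)."""
    counts = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(true_labels, pred_labels):
        counts[t - 1][p - 1] += 1
    cm = []
    for row in counts:
        total = sum(row)
        cm.append([c / total if total else 0.0 for c in row])  # empty class -> zero row
    return cm


cm = confusion_matrix([1, 1, 2, 2, 3, 3], [1, 2, 2, 2, 3, 1], 3)
print(cm)  # [[0.5, 0.5, 0.0], [0.0, 1.0, 0.0], [0.5, 0.0, 0.5]]
```

Counting first and normalising per row afterwards keeps the same output as the MATLAB version (fractions per true class) while dropping the ordering assumption.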

draw_cm.m

function draw_cm(mat, tick, num_class)
% Draw a confusion matrix as an image with the value printed in each cell

imagesc(1:num_class, 1:num_class, mat);   % in color
colormap(flipud(gray));                   % gray scale; black for large values

textStrings = num2str(mat(:), '%0.2f');
textStrings = strtrim(cellstr(textStrings));
[x, y] = meshgrid(1:num_class);
hStrings = text(x(:), y(:), textStrings(:), 'HorizontalAlignment', 'center');
midValue = mean(get(gca, 'CLim'));
textColors = repmat(mat(:) > midValue, 1, 3);        % white text on dark cells
set(hStrings, {'Color'}, num2cell(textColors, 2));   % change the text colors

set(gca, 'XTick', 1:num_class, 'YTick', 1:num_class);
set(gca, 'xticklabel', tick, 'XAxisLocation', 'top');
set(gca, 'yticklabel', tick);
rotateXLabels(gca, 315);   % rotate the x tick labels (rotateXLabels is a MATLAB File Exchange function, not built in)

visualize.m

function r=visualize(X, mm, s1, s2)
%FROM RBMLIB http://code.google.com/p/matrbm/
%Visualize weights X. If the function is called as a void method,
%it does the plotting. But if the function is assigned to a variable 
%outside of this code, the formed image is returned instead.
if ~exist('mm','var')
    mm = [min(X(:)) max(X(:))];
end
if ~exist('s1','var')
    s1 = 0;
end
if ~exist('s2','var')
    s2 = 0;
end

[D,N]= size(X);
s=sqrt(D);
if s==floor(s) || (s1 ~=0 && s2 ~=0)
    if (s1 ==0 || s2 ==0)
        s1 = s; s2 = s;
    end
    %its a square, so data is probably an image
    num=ceil(sqrt(N));
    a=mm(2)*ones(num*s2+num-1,num*s1+num-1);
    x=0;
    y=0;
    for i=1:N
        im = reshape(X(:,i),s1,s2)';
        a(x*s2+1+x : x*s2+s2+x, y*s1+1+y : y*s1+s1+y)=im;
        x=x+1;
        if(x>=num)
            x=0;
            y=y+1;
        end
    end
    d=true;
else
    %there is not much we can do
    a=X;
end

%return the image, or plot the image
if nargout==1
    r=a;
else

    imagesc(a, [mm(1) mm(2)]);
    axis equal
    colormap gray

end

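The resulting recognition accuracy is 87.040%. Beyond this linear model, the hinge loss can also be used as the classification loss on top of a deep network.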
The resulting confusion matrix is shown below:
[Figure: confusion matrix]

Loss curve:
[Figure: loss function over the training iterations]

The dataset is available in the author's resources. If you use this code, please credit the source.

