zoukankan      html  css  js  c++  java
  • 随机梯度下降算法求解SVM

    测试代码(matlab)如下:

    clear;
    % One-vs-rest linear SVM on USPS, trained with stochastic gradient descent.
    %
    % NOTE(review): the path separators were lost in the published listing;
    % the path below is the reconstructed Windows path - adjust to your setup.
    load('E:\dataset\USPS\USPS.mat');
    % data format:
    % Xtr n1*dim
    % Xte n2*dim
    % Ytr n1*1
    % Yte n2*1
    % Labels must be exactly 1..n (n = number of classes): the class index is
    % used both as the training target and as the predicted label.
    u=unique(Ytr);
    Nclass=length(u);
    if ~isequal(double(u(:)), (1:Nclass)')
        error('Labels in Ytr must be exactly the integers 1..%d.', Nclass);
    end

    % Training hyper-parameters consumed by sgd_svm.
    param.iterations=1;   % number of passes over the training set
    param.lambda=1e-3;    % regularisation strength
    param.biaScale=1;     % multiplier applied to bias updates
    param.t0=100;         % initial value of the step counter

    % Preallocate one weight row / one bias per class instead of growing
    % the arrays inside the loop.
    allw=zeros(Nclass, size(Xtr,2));
    allb=zeros(Nclass, 1);

    tic;
    for classname=1:Nclass
        % Relabel: current class -> +1, every other class -> -1.
        temp_Ytr=change_label(Ytr,classname);
        [w,b]=sgd_svm(Xtr,temp_Ytr,param);
        allw(classname,:)=w;
        allb(classname)=b;
        fprintf('class %d is done\n', classname);
    end

    [accuracy, predict_label]=my_svmpredict(Xte, Yte, allw, allb);
    fprintf('accuracy is %.2f percent.\n', accuracy*100);
    toc;

    function [temp_Ytr] = change_label(Ytr,classname)
    % CHANGE_LABEL  Convert multi-class labels to binary +1/-1 labels.
    %   temp_Ytr = change_label(Ytr, classname) returns an array of the same
    %   size as Ytr holding +1 where Ytr == classname and -1 elsewhere.
    %
    %   The result is always of class double. Copying Ytr and writing -1 into
    %   it would saturate to 0 for unsigned integer label arrays (e.g. uint8),
    %   silently destroying the negative class.
    temp_Ytr=-ones(size(Ytr));
    temp_Ytr(Ytr==classname)=1;


    function [true_W,b]=sgd_svm(X,Y,param)
    % SGD_SVM  Train a linear SVM (hinge loss, L2 regulariser) with SGD.
    % input:
    %   X     n*dim sample matrix, one sample per row
    %   Y     n*1 labels, each +1 or -1 (the margin test below relies on it)
    %   param struct with fields
    %         iterations  number of passes over the data
    %         lambda      regularisation strength (must be > 0)
    %         biaScale    multiplier applied to the bias update
    %         t0          initial step counter; step size is 1/(lambda*t)
    % output:
    %   true_W 1*dim weight row vector, so the score is X*true_W'+b
    %   b      scalar bias
    iterations=param.iterations;
    lambda=param.lambda;
    biaScale=param.biaScale;
    t=param.t0;

    nSamples=size(X,1);
    if numel(Y)~=nSamples
        error('sgd_svm:sizeMismatch', 'Y must have one label per row of X.');
    end
    % Integer-typed labels would make the margin arithmetic round/saturate.
    Y=double(Y(:));

    w=zeros(1,size(X,2));
    bias=0;

    for k=1:iterations
        for i=1:nSamples
            t=t+1;
            alpha=1.0/(lambda*t);
            % Convert one row at a time so integer/single X also works
            % without copying the whole matrix.
            xi=double(X(i,:));
            % Margin is evaluated with the weights from before this step.
            violated=Y(i)*(xi*w'+bias)<1;
            % Gradient step on the regulariser (lambda/2)*||w||^2:
            % w <- (1 - alpha*lambda)*w, and alpha*lambda == 1/t.
            % The bias is not regularised.
            w=(1-1/t)*w;
            if violated
                % Sub-gradient step on the hinge loss.
                bias=bias+alpha*Y(i)*biaScale;
                w=w+alpha*Y(i)*xi;
            end
        end
    end
    b=bias;
    true_W=w;

    function [accuracy, predict_label]=my_svmpredict(Xte, Yte, allw, allb)
    % MY_SVMPREDICT  One-vs-rest prediction with a bank of linear classifiers.
    % input:
    %   Xte  n*dim test samples, one per row
    %   Yte  n*1 true labels; must range from 1 to nclass because the
    %        predicted label is the row index of the winning classifier
    %   allw nclass*dim weight matrix, one classifier per row
    %   allb nclass*1 biases
    % output:
    %   accuracy      fraction of rows whose predicted label equals Yte
    %   predict_label n*1 index of the highest-scoring classifier per sample
    %                 (the first index wins on ties)

    % Replicate the biases over the rows of Xte (the original referenced an
    % undefined variable here).
    score=Xte*allw'+repmat(allb(:)',[size(Xte,1),1]);
    % Only the arg-max is needed, so a full sort is unnecessary.
    [~,predict_label]=max(score,[],2);
    % Compare in double: subtracting unsigned integer labels saturates at 0
    % and would count under-predictions as correct.
    right=sum(predict_label==double(Yte(:)));
    accuracy=right/size(Yte,1);

  • 相关阅读:
    QQ空间鼠标代码
    QQ空间Flash
    QQ播放器代码
    QQ空间鼠标代码
    QQ空间Flash
    QQ空间Flash
    第二届“携进杯”师生羽毛球联谊赛
    DataView对象
    数据控件DataGrid数据控件
    数据控件Repeater数据控件
  • 原文地址:https://www.cnblogs.com/zhouxiaohui888/p/6077994.html
Copyright © 2011-2022 走看看