zoukankan      html  css  js  c++  java
  • SMO推导和代码-记录毕业论文4

    我通过阅读Platt的论文和这篇博客（http://www.cnblogs.com/jerrylead/archive/2011/03/18/1988419.html ）大致弄懂了SMO的数学公式。推导过程以后再写，这里先贴出我自己实现的SMO代码（MATLAB）。

    function [ model ] = smoSolver( designMatrix, targetGroup, C )
    %SMOSOLVER Train a linear-kernel soft-margin SVM with Platt's SMO algorithm.
    %   model = smoSolver(designMatrix, targetGroup) trains with box constraint C = 1.
    %   model = smoSolver(designMatrix, targetGroup, C) uses the given C, the
    %   upper bound placed on every Lagrange multiplier.
    %
    %   designMatrix : n-by-d matrix, one training sample per row.
    %   targetGroup  : n class labels (row or column vector); assumed to be
    %                  +1 / -1, as the error and sign computations below require.
    %   C            : optional positive scalar, default 1.
    %
    %   model.alpha       : multipliers larger than the tolerance (row vector)
    %   model.targetGroup : labels of the corresponding samples (row vector)
    %   model.supVec      : the corresponding rows of designMatrix
    %   model.b           : threshold; decision value is (alpha.*y) * K - b
    if nargin < 3 || isempty(C)
        C = 1;
    end
    % Keep labels as a row so that alphaArray .* targetGroup is element-wise
    % (a column vector would broadcast into an n-by-n matrix).
    targetGroup = reshape(targetGroup, 1, []);

    numChanged = 0;
    examineAll = 1;
    tolerance = 0.001; total_runtimes = 5000; epsilon = 0.01;
    n_samps = size(designMatrix, 1);
    % Linear kernel: Gram matrix of all training samples.
    kernelMatrix = designMatrix * designMatrix';
    % SMO must start from a feasible point. All-zero multipliers satisfy both
    % 0 <= alpha <= C and sum(alpha .* y) = 0, and every joint step preserves them.
    alphaArray = zeros(1, n_samps);
    b = 0;
    u = (alphaArray .* targetGroup) * kernelMatrix - b;
    E = u - targetGroup;   % error cache, kept current by takeStep
    iter = 1;
    while numChanged > 0 || examineAll
        numChanged = 0;
        for i = 1 : n_samps
            % Full sweeps visit every sample; otherwise only non-bound multipliers.
            if examineAll || (abs(alphaArray(i)) > tolerance && abs(alphaArray(i) - C) > tolerance)
                numChanged = numChanged + examineExample(i);
            end
        end
        if examineAll == 1
            examineAll = 0;
        elseif numChanged == 0
            examineAll = 1;
        end
        iter = iter + 1;
        if iter > total_runtimes
            break;
        end
    end

    alphaIdx = find(abs(alphaArray) > tolerance);
    model.targetGroup = targetGroup(alphaIdx);
    model.alpha = alphaArray(alphaIdx);
    model.supVec = designMatrix(alphaIdx, :);
    model.b = b;

    function changed = examineExample(i2)
    % Check sample i2 against the KKT conditions. If it violates them, look for
    % a partner i1 and attempt a joint step. Returns 1 if a step was taken.
        changed = 0;
        y2 = targetGroup(i2);
        alpha2 = alphaArray(i2);
        r2 = E(i2) * y2;   % equals y2*u2 - 1
        atLower = abs(alpha2) < tolerance;
        atUpper = abs(alpha2 - C) < tolerance;
        isFree = alpha2 > tolerance && alpha2 < C - tolerance;
        violatesKKT = (r2 < -tolerance && atLower) || ...
                      (r2 > tolerance && atUpper) || ...
                      (abs(r2) > tolerance && isFree);
        if ~violatesKKT
            return;
        end
        nonBound = find(alphaArray > tolerance & alphaArray < C - tolerance);
        % Second-choice heuristic: the partner with the largest |E1 - E2|
        % approximates the largest step (max returns the first index on ties).
        if numel(nonBound) > 1
            [~, i1] = max(abs(E - E(i2)));
            if takeStep(i1, i2)
                changed = 1; return;
            end
        end
        % Fall back to every non-bound sample, then to every sample.
        for i1 = nonBound
            if takeStep(i1, i2)
                changed = 1; return;
            end
        end
        for i1 = 1 : n_samps
            if takeStep(i1, i2)
                changed = 1; return;
            end
        end
    end

    function tf = takeStep(i1, i2)
    % Analytically optimise alphaArray(i1) and alphaArray(i2) together, then
    % update b and the error cache E. Returns 1 if the pair moved.
        tf = 0;
        if i1 == i2
            return;
        end
        alpha1 = alphaArray(i1); alpha2 = alphaArray(i2);
        y1 = targetGroup(i1);    y2 = targetGroup(i2);
        E1 = E(i1);              E2 = E(i2);
        s = y1 * y2;
        % Feasible segment [L, H] for the new alpha2 while alpha1*y1 + alpha2*y2 stays fixed.
        if s > 0
            L = max(0, alpha1 + alpha2 - C);
            H = min(C, alpha1 + alpha2);
        else
            L = max(0, alpha2 - alpha1);
            H = min(C, C + alpha2 - alpha1);
        end
        if L == H
            return;
        end
        k11 = kernelMatrix(i1, i1);
        k12 = kernelMatrix(i1, i2);
        k22 = kernelMatrix(i2, i2);
        eta = k11 + k22 - 2*k12;   % curvature along the constraint line
        if eta > 0
            a2 = alpha2 + y2 * (E1 - E2) / eta;
            a2 = min(max(a2, L), H);
        else
            % Non-positive curvature: compare the dual objective at both ends.
            Lobj = dualObjective(i1, alpha1 + s*(alpha2 - L), i2, L);
            Hobj = dualObjective(i1, alpha1 + s*(alpha2 - H), i2, H);
            if Lobj < Hobj - epsilon
                a2 = L;
            elseif Lobj > Hobj + epsilon
                a2 = H;
            else
                a2 = alpha2;
            end
        end
        if abs(a2 - alpha2) < epsilon * (a2 + alpha2 + epsilon)
            return;
        end
        a1 = alpha1 + s*(alpha2 - a2);
        % Thresholds that make the new output equal the label for sample 1 (b1) / 2 (b2).
        b1 = E1 + y1*(a1 - alpha1)*k11 + y2*(a2 - alpha2)*k12 + b;
        b2 = E2 + y1*(a1 - alpha1)*k12 + y2*(a2 - alpha2)*k22 + b;
        if a1 > 0 && a1 < C
            b = b1;
        elseif a2 > 0 && a2 < C
            b = b2;
        else
            b = (b1 + b2) / 2;
        end
        alphaArray(i1) = a1; alphaArray(i2) = a2;
        u = (alphaArray .* targetGroup) * kernelMatrix - b;
        E = u - targetGroup;
        tf = 1;
    end

    function obj = dualObjective(j1, aj1, j2, aj2)
    % Dual objective to minimise, 0.5*sum_ij y_i y_j a_i a_j K_ij - sum_i a_i,
    % evaluated with alphaArray(j1) = aj1 and alphaArray(j2) = aj2.
        candidate = alphaArray;
        candidate(j1) = aj1; candidate(j2) = aj2;
        weighted = candidate .* targetGroup;
        obj = 0.5 * (weighted * kernelMatrix * weighted') - sum(candidate);
    end

    end
    

    预测函数 smoPredict：

    function [ targetGroup ] = smoPredict( model, designMatrix )
    %SMOPREDICT Classify samples with a model returned by smoSolver.
    %   targetGroup = smoPredict(model, designMatrix) returns a 1-by-m row holding
    %   sign(decision value) for each of the m rows of designMatrix (0 exactly
    %   on the boundary). Decision value (linear kernel):
    %       sum_k alpha_k * y_k * <supVec_k, x> - b
    % One weight per support vector; (:)' accepts row- or column-stored fields.
    weights = model.alpha(:)' .* model.targetGroup(:)';
    % supVec * designMatrix' is nSV-by-m, so the product is 1-by-m: one value
    % per sample, summed over support vectors (not over test samples).
    u = weights * (model.supVec * designMatrix') - model.b;
    targetGroup = sign(u);
    end
  • 相关阅读:
    css控制英文内容自动换行問題
    jquery添加select option两种代码思路比较
    C++实现单例模式
    C++实现单例模式
    windows下socket编程:区分shutdown()及closesocket()
    windows下socket编程:区分shutdown()及closesocket()
    socket关闭
    socket关闭
    C++模板的一些巧妙功能
    C++模板的一些巧妙功能
  • 原文地址:https://www.cnblogs.com/Key-Ky/p/5117185.html
Copyright © 2011-2022 走看看