zoukankan      html  css  js  c++  java
  • Sequential Minimal Optimization (SMO) for SVM (MATLAB code)

    function model = SMOforSVM(X, y, C, tol, maxIters)
    %SMOFORSVM Train a linear soft-margin SVM with sequential minimal optimization.
    %   model = SMOforSVM(X, y, C)
    %   model = SMOforSVM(X, y, C, tol, maxIters)
    %
    %   Solves the dual problem
    %       min_a  0.5*a'*Q*a - sum(a),   Q = (y*y').*(X*X')
    %       s.t.   y'*a = 0,  0 <= a <= C
    %   by repeatedly picking a maximal-violating pair (selectWorkSet) and
    %   solving the two-variable subproblem analytically (solveOptimization).
    %
    %   Inputs:
    %     X        - m-by-n data matrix, one sample per row.
    %     y        - m labels in {+1, -1} (row or column vector).
    %     C        - box constraint on the dual variables.
    %     tol      - optional stopping tolerance on m1 - M1 (default 0.001).
    %     maxIters - optional maximum number of pair updates (default 3000).
    %
    %   Output (struct):
    %     model.alpha - m-by-1 dual variables.
    %     model.w     - 1-by-n weight vector, w = (y.*alpha)'*X.
    %     model.b     - bias, so the decision value of a row x is x*w' + b.
    %
    %   The kernel matrix, dual variables and working pair are shared with the
    %   helper functions through global variables.

    global i1 i2 K Alpha M1 m1 w b

    if nargin < 4 || isempty(tol),      tol = 0.001;     end
    if nargin < 5 || isempty(maxIters), maxIters = 3000; end

    % The helpers combine y element-wise with the column vector Alpha and form
    % y*y'-style products, so y must be a column; a row y gives wrong shapes.
    y = y(:);
    if size(X, 1) ~= numel(y)
        error('SMOforSVM:size', 'X must have one row per label in y.');
    end
    if any(y ~= 1 & y ~= -1)
        error('SMOforSVM:labels', 'Labels y must be +1 or -1.');
    end
    if ~any(y == 1) || ~any(y == -1)
        error('SMOforSVM:labels', 'y must contain both classes (+1 and -1).');
    end

    K = X * X';                  % linear-kernel Gram matrix
    Alpha = zeros(numel(y), 1);
    w = 0; b = 0;

    converged = false;
    for iter = 1:maxIters
        [i1, i2, m1, M1] = selectWorkSet(y, C);
        if m1 - M1 <= tol        % no pair violates the KKT conditions by more than tol
            converged = true;
            break;
        end
        solveOptimization(X, y, C);
    end
    if ~converged
        % The last update was never checked inside the loop; check it now.
        [i1, i2, m1, M1] = selectWorkSet(y, C);
        if m1 - M1 > tol
            warning('SMOforSVM:maxIters', ...
                'SMO stopped after %d updates without reaching tol = %g.', maxIters, tol);
        end
    end

    ya = y .* Alpha;

    % For a free support vector s (0 < alpha_s < C): y_s*(x_s*w' + b) = 1,
    % hence b = y_s - x_s*w'. Averaging over all free SVs is more robust than
    % using a single one.
    freeIdx = find(Alpha > 0 & Alpha < C);
    if ~isempty(freeIdx)
        b = mean(y(freeIdx)' - ya' * K(:, freeIdx));
    else
        % No free SV: every b with max_{I_up} g <= b <= min_{I_low} g satisfies
        % the KKT conditions (g = -y.*gradient); take the midpoint (the rule
        % LIBSVM uses). Both sets are non-empty because y'*Alpha = 0 holds
        % and both classes are present.
        g = -y .* (y .* (K * ya) - 1);
        upSet  = (Alpha < C & y == 1)  | (Alpha > 0 & y == -1);
        lowSet = (Alpha < C & y == -1) | (Alpha > 0 & y == 1);
        b = (max(g(upSet)) + min(g(lowSet))) / 2;
    end

    w = ya' * X;

    model.alpha = Alpha;
    model.w = w;
    model.b = b;
    end
    
    %Selecting working set B
    function [i1,i2,m1,M1]=selectWorkSet(y, C)
    %SELECTWORKSET Pick the maximal-violating pair (i1, i2) for the next SMO step.
    %   Reads globals K (kernel matrix) and Alpha (current dual variables).
    %   With Q = (y*y').*K, the gradient of the dual objective is Q*Alpha - 1
    %   and g = -y .* (Q*Alpha - 1).
    %     i1 = argmax of g over I_up,  m1 = g(i1)
    %     i2 = argmin of g over I_low, M1 = g(i2)
    %   The caller stops when m1 - M1 <= tol.
    %   If I_up or I_low is empty (e.g. y contains a single class), no pair
    %   exists: i1 and i2 are returned empty with m1 = -Inf and M1 = Inf, so
    %   the caller's stopping test succeeds instead of indexing with [].
    global  K Alpha

    I_up  = find((Alpha < C & y == 1)  | (Alpha > 0 & y == -1));
    I_low = find((Alpha < C & y == -1) | (Alpha > 0 & y == 1));

    if isempty(I_up) || isempty(I_low)
        i1 = []; i2 = [];
        m1 = -Inf; M1 = Inf;
        return;
    end

    % ((y*y').*K)*Alpha == y .* (K*(y.*Alpha)); this avoids building an
    % m-by-m matrix on every call.
    yGradient = -y .* (y .* (K * (y .* Alpha)) - 1);

    [m1, k1] = max(yGradient(I_up));
    [M1, k2] = min(yGradient(I_low));

    i1 = I_up(k1);
    i2 = I_low(k2);
    end
    
    
    %Solving the two-variables optimization problem
    function solveOptimization(X, y, C)
    %SOLVEOPTIMIZATION Analytically solve the two-variable SMO subproblem.
    %   Updates the globals Alpha(i1) and Alpha(i2) in place. The step keeps
    %   y(i1)*Alpha(i1) + y(i2)*Alpha(i2) unchanged and both values in [0, C].
    %   Reads the globals K (kernel matrix) and i1, i2 (working pair), and
    %   stores the error difference E1 - E2 in the global E.
    %   X is not used; it is kept so the caller's signature stays valid.
    global Alpha K i1 i2 E

    % Lower bound for the curvature eta = K11 + K22 - 2*K12. eta is 0 for
    % duplicate points and can round to a tiny negative value (or be negative
    % for a non-PSD K); that would push the step toward the wrong bound and
    % stall the pair. This is the same guard LIBSVM applies.
    TAU = 1e-12;

    alpha1_old = Alpha(i1);
    alpha2_old = Alpha(i2);
    y1 = y(i1);
    y2 = y(i2);

    beta11 = K(i1,i1); beta22 = K(i2,i2); beta12 = K(i1,i2);

    % Contribution of all other samples to the (bias-free) outputs at x1 and x2.
    id = 1:numel(Alpha);
    id([i1 i2]) = [];
    beta1 = sum(y(id) .* Alpha(id) .* K(id,i1));
    beta2 = sum(y(id) .* Alpha(id) .* K(id,i2));

    % E = (f(x1) - y1) - (f(x2) - y2), where f omits the bias.
    E = beta1 - beta2 + alpha1_old * y1 * (beta11 - beta12) ...
        + alpha2_old * y2 * (beta12 - beta22) - y1 + y2;

    kk = beta11 + beta22 - 2 * beta12;
    if kk <= 0
        kk = TAU;
    end
    alpha2_new_unc = alpha2_old + (y2 * E) / kk;

    % Feasible segment [L, H] for alpha2 along the equality constraint.
    if y1 ~= y2
        L = max(0, alpha2_old - alpha1_old);
        H = min(C, C - alpha1_old + alpha2_old);
    else
        L = max(0, alpha1_old + alpha2_old - C);
        H = min(C, alpha1_old + alpha2_old);
    end

    alpha2_new = min(max(alpha2_new_unc, L), H);
    alpha1_new = alpha1_old + y1 * y2 * (alpha2_old - alpha2_new);

    Alpha(i1) = alpha1_new;
    Alpha(i2) = alpha2_new;
    end
    

      

    % Interactive demo: click points for two classes, train a linear SVM with
    % SMOforSVM, then draw the decision line (red) and margin lines (blue).
    % getpts requires the Image Processing Toolbox.
    clear
    X = []; Y = [];
    figure;
    % Initialize training data to empty; points are obtained from the user.
    trainPoints = X;
    trainLabels = Y;
    clf;
    axis([-5 5 -5 5]);
    if isempty(trainPoints)
        % Marker symbols and label values for class 1 and class 2
        symbols = {'o','x'};
        classvals = [-1 1];
        trainLabels = [];
        hold on; % keep earlier points when plotting new ones
        xlim([-5 5]); ylim([-5 5]);

        for c = 1:2
            title(sprintf('Click to create points from class %d. Press enter when finished.', c));
            [x, y] = getpts;

            plot(x, y, symbols{c}, 'LineWidth', 2, 'Color', 'black');

            % Grow the data and label matrices
            trainPoints = vertcat(trainPoints, [x y]);
            trainLabels = vertcat(trainLabels, repmat(classvals(c), numel(x), 1));
        end
    end

    C = 10;
    par = SMOforSVM(trainPoints, trainLabels, C);

    if size(trainPoints, 2) == 2
        hold on
        xr = [-5 5];
        % Lines x*w' + b = offset: 0 is the decision line, -1/+1 are the margins.
        offsets = [0, -1, 1];
        colors  = {'r', 'b', 'b'};
        for j = 1:numel(offsets)
            if par.w(2) ~= 0
                % Solve w1*x1 + w2*x2 + b = offset for x2.
                x2 = (offsets(j) - par.b - par.w(1) * xr) / par.w(2);
                plot(xr, x2, colors{j});
            else
                % w2 == 0: the line is vertical, x1 = (offset - b) / w1.
                x1 = (offsets(j) - par.b) / par.w(1);
                plot([x1 x1], xr, colors{j});
            end
        end
    end
    xlim([-5 5]); ylim([-5 5]);

     以上代码写得比较粗糙，结果可能不稳定，我重新贴了一份新的代码：

    http://www.cnblogs.com/huadongw/p/4994657.html

  • 相关阅读:
    set转成toarray()
    list和set的拉拉扯扯的关系
    【转载】VNC和远程桌面的区别
    笔记本最小安装centos7 连接WiFi的方法
    mysql 索引优化 性能调优 锁
    PageHelper 自动去掉排序参数问题
    抽奖算法 百万次抽奖 单线程环境下 约 3.5 秒
    gitlab 安装和使用
    sharding sphere 分表分库 读写分离
    mycat 安装 分表 分库 读写分离
  • 原文地址:https://www.cnblogs.com/huadongw/p/4601377.html
Copyright © 2011-2022 走看看