zoukankan      html  css  js  c++  java
  • SVM 之 MATLAB 实现代码

    MATLAB 中 SVM 实现

    直接上代码

    • main.m
    
    %% Initialize data
    clear, clc, close all;
    load('data.mat');
    
    y(y == 0) = -1;
    % X_train = X(1:35, :);
    % y_train = y(1:35);
    % X_test = X(36:51, :);
    % y_test = y(36:51);
    
    %% Visualize data
    % jhplotdata(X_train, y_train);
    
    %% Training a SVM(Support Vector Machine) Classifier
    C = 10;
    svm = jhsvmtrain(X, y, C, 'Linear');
    result = jhsvmtest(svm, X);
    fprintf('Accuracy: %f
    ', mean(double(result.y_pred == y)));
    jhplotdata(X, y);
    hold on;
    x1_min = min(X(:, 1)) - 1;
    x1_max = max(X(:, 1)) + 1;
    x2_min = min(X(:, 2)) - 1;
    x2_max = max(X(:, 2)) + 1;
    
    [XX, YY] = meshgrid(x1_min:0.02:x1_max, x2_min:0.02:x2_max);
    Z = jhsvmtest(svm, [XX(:) YY(:)]);
    Z = reshape(Z.y_pred, size(XX));
    contour(XX, YY, Z);
    hold off;
    
    • jhsvmtrain.m
    
    function [model] = jhsvmtrain(X, y, C, kernel_type)
    %% 函数的核心就是对拉格朗日对偶式的二次规划问题, 通过返回的alpha得到我们需要的支持向量
    
    % convert the primal problem to a dual problem, the dual problem is written
    % below.
    
    % the number of training examples.
    m = length(y);
    
    % NOTE!! The following two statements, which represent the 
    % target function, are fixed cause our target function is fixed.
    H = y * y' * jhkernels(X', X', kernel_type);
    f = -ones(m, 1);
    A = [];
    b = [];
    Aeq = y';
    beq = 0;
    lb = zeros(m, 1);
    % C is the regularization parameter which means that our model has already
    % taken the error into the consideration.
    ub = C * ones(m, 1);
    alphas0 = zeros(m, 1);
    
    epsilon = 1e-8;
    options = optimset('LargeScale', 'off', 'Display', 'off');
    alphas1 = quadprog(H, f, A, b, Aeq, beq, lb, ub, alphas0, options);
    
    logic_vector = abs(alphas1) > epsilon;
    model.vec_x = X(logic_vector, :);
    model.vec_y = y(logic_vector);
    model.alphas = alphas1(logic_vector);
    
    end
    
    • jhsvmtest.m
    
    function result = jhsvmtest(model, X)
    % 在svmTrain中我们主要计算的就是那几个支持向量, 对应的, 核心就是alpha
    % 现在有了alpha, 我们通过公式可以轻而易举地计算出w, 我们还不知道b的值, 也即是超平面偏差的值
    % 所有先将我们的支持向量代入到公式中, 计算出一个临时的w
    % 对于一直的支持向量来说, 我们已经知道了它的label, 所有可以计算出b, 将超平面拉过来, 再将这个b运用到测试集中即可
    
    % 带入公式w = sum_{i=1}^{m}alpha^{(i)}y^{(i)}x^{(i)}^Tx
    % x是输入需要预测的值
    tmp = (model.alphas' .* model.vec_y' * jhkernels(model.vec_x', model.vec_x', 'Linear'))';
    % 计算出偏差, 也就是超平面的截距
    total_bias = model.vec_y - tmp;
    bias = mean(total_bias);
    
    % 我们已经得到了apha, 因为w是由alpha表示的, 所以通过alpha可以计算出w
    % w = sum(alpha .* y_sv)*kernel(x_sv, x_test)
    % 其中y_sv是sv的标签, x_sv是sv的样本, x_test是需要预测的数据
    w = (model.alphas' .* model.vec_y' * jhkernels(model.vec_x', X', 'Linear'))';
    result.w = w;
    result.y_pred = sign(w + bias);
    result.b = bias;
    end
    
    • jhkernel.m
    
    function K = jhkernels(X1, X2, kernel_type)
    
    switch lower(kernel_type)
        
        case 'linear'
            K = X1' * X2;
        
        case 'rbf'
            K = X1' * X2;
            fprintf("I am sorry about that the rbg kernel is not implemented yet, here we still use the linear kernel to compute
    ");    
    end
    
    end
    
    • jhplotdata.m
    
    function jhplotdata(X, y, caption, labelx, labely, color1, color2)
    
    if ~exist('caption', 'var') || isempty(caption)
       caption = 'The relationship between X1 and X2';
    end
    
    if ~exist('labelx', 'var') || isempty(labelx)
       labelx = 'X1';
    end
    
    if ~exist('labely', 'var') || isempty(labely)
       labely = 'X2';
    end
    
    if ~exist('color1', 'var') || isempty(color1)
       color1 = 'r'; 
    end
    
    if ~exist('color2', 'var') || isempty(color2)
       color2 = 'r'; 
    end
    
    % JHPLOTDATA is going to plot two dimentional data
    positive = find(y == 1);
    negative = find(y == -1);
    
    plot(X(positive, 1), X(positive, 2), 'ro', 'MarkerFace', color1);
    hold on;
    
    plot(X(negative, 1), X(negative, 2), 'bo', 'MarkerFace', color2);
    title(caption);
    xlabel(labelx);
    ylabel(labely);
    legend('Positive Data', 'Negative Data');
    
    hold off;
    end
    
  • 相关阅读:
    MySql 学习之 一条更新sql的执行过程
    MySql 学习之 一条查询sql的执行过程
    VUE基本介绍
    ESMAScript6基本介绍
    npm
    tensorflow2.0 评估函数
    网页引入mathjax,latex
    Veno File Manager
    tensorflow 测量工具,与自定义训练
    tensorflow自定义网络结构
  • 原文地址:https://www.cnblogs.com/megachen/p/10024059.html
Copyright © 2011-2022 走看看