zoukankan      html  css  js  c++  java
  • SVM回归

    SVM回归任务是限制间隔违规情况下,尽量防止更多的样本在“街道”上。“街道”的宽度由超参数(epsilon)控制
    在随机生成的线性数据上,两个线性SVM回归模型,一个有较大的间隔((epsilon=1.5)),另一个间隔较小((epsilon=0.5)),训练情况如下:
    代码如下:
    造数据与训练:

    np.random.seed(42)
    m = 50
    X = 2 * np.random.randn(m,1)
    y = (4 + 3 * X + np.random.randn(m,1)).ravel()
    
    from sklearn.svm import LinearSVR
    
    svm_reg1 = LinearSVR(epsilon=1.5, random_state=42)
    svm_reg2 = LinearSVR(epsilon=0.5, random_state=42)
    
    svm_reg1.fit(X, y)
    svm_reg2.fit(X,y)
    

    可视化编码

    def find_support_vectors(svm_reg, X, y):
        y_pred = svm_reg.predict(X)
        off_margin = (np.abs(y - y_pred) >= svm_reg.epsilon)
        return np.argwhere(off_margin)
    
    svm_reg1.support_ = find_support_vectors(svm_reg1, X, y)
    svm_reg2.support_ = find_support_vectors(svm_reg2, X, y)
    
    eps_x1 = 1
    eps_y_pred = svm_reg1.predict([[eps_x1]])
    def plot_svm_regression(svm_reg, X, y, axes):
        x1s = np.linspace(axes[0], axes[1], 100).reshape(100, 1)
        y_pred = svm_reg.predict(x1s)
        plt.plot(x1s, y_pred, "k-", linewidth=2, label=r"$hat{y}$")
        plt.plot(x1s, y_pred + svm_reg.epsilon, "k--")
        plt.plot(x1s, y_pred - svm_reg.epsilon, "k--")
        plt.scatter(X[svm_reg.support_], y[svm_reg.support_], s=180, facecolors='#FFAAAA')
        plt.plot(X, y, "bo")
        plt.xlabel(r"$x_1$", fontsize=18)
        plt.legend(loc="upper left", fontsize=18)
        plt.axis(axes)
    
    plt.figure(figsize=(9, 4))
    plt.subplot(121)
    plot_svm_regression(svm_reg1, X, y, [0, 2, 3, 11])
    plt.title(r"$epsilon = {}$".format(svm_reg1.epsilon), fontsize=18)
    plt.ylabel(r"$y$", fontsize=18, rotation=0)
    #plt.plot([eps_x1, eps_x1], [eps_y_pred, eps_y_pred - svm_reg1.epsilon], "k-", linewidth=2)
    plt.annotate(
            '', xy=(eps_x1, eps_y_pred), xycoords='data',
            xytext=(eps_x1, eps_y_pred - svm_reg1.epsilon),
            textcoords='data', arrowprops={'arrowstyle': '<->', 'linewidth': 1.5}
        )
    plt.text(0.91, 5.6, r"$epsilon$", fontsize=20)
    plt.subplot(122)
    plot_svm_regression(svm_reg2, X, y, [0, 2, 3, 11])
    plt.title(r"$epsilon = {}$".format(svm_reg2.epsilon), fontsize=18)
    
    plt.show()
    

    可视化展示:

    非线性拟合

    造数据

    np.random.seed(42)
    m = 100
    X = 2 * np.random.rand(m, 1) - 1
    y = (0.2 + 0.1 * X + 0.5 * X**2 + np.random.randn(m, 1)/10).ravel()
    
    from sklearn.svm import SVR
    
    from sklearn.svm import SVR
    
    svm_poly_reg1 = SVR(kernel="poly", degree=2, C=100, epsilon=0.1, gamma="auto")
    svm_poly_reg2 = SVR(kernel="poly", degree=2, C=0.01, epsilon=0.1, gamma="auto")
    svm_poly_reg1.fit(X, y)
    svm_poly_reg2.fit(X, y)
    

    可视化编程

    plt.figure(figsize=(9, 4))
    plt.subplot(121)
    plot_svm_regression(svm_poly_reg1, X, y, [-1, 1, 0, 1])
    plt.title(r"$degree={}, C={}, epsilon = {}$".format(svm_poly_reg1.degree, svm_poly_reg1.C, svm_poly_reg1.epsilon), fontsize=18)
    plt.ylabel(r"$y$", fontsize=18, rotation=0)
    plt.subplot(122)
    plot_svm_regression(svm_poly_reg2, X, y, [-1, 1, 0, 1])
    plt.title(r"$degree={}, C={}, epsilon = {}$".format(svm_poly_reg2.degree, svm_poly_reg2.C, svm_poly_reg2.epsilon), fontsize=18)
    
    plt.show()
    

    可视化展示:

  • 相关阅读:
    bzoj2733 永无乡 平衡树按秩合并
    bzoj2752 高速公路 线段树
    bzoj1052 覆盖问题 二分答案 dfs
    bzoj1584 打扫卫生 dp
    bzoj1854 游戏 二分图
    bzoj3316 JC loves Mkk 二分答案 单调队列
    bzoj3643 Phi的反函数 数学 搜索
    有一种恐怖,叫大爆搜
    BZOJ3566 概率充电器 概率dp
    一些奇奇怪怪的过题思路
  • 原文地址:https://www.cnblogs.com/whiteBear/p/13096398.html
Copyright © 2011-2022 走看看