zoukankan      html  css  js  c++  java
  • SVM回归

    SVM回归任务是限制间隔违规情况下,尽量防止更多的样本在“街道”上。“街道”的宽度由超参数(epsilon)控制
    在随机生成的线性数据上,两个线性SVM回归模型,一个有较大的间隔((epsilon=1.5)),另一个间隔较小((epsilon=0.5)),训练情况如下:
    代码如下:
    造数据与训练:

    np.random.seed(42)
    m = 50
    X = 2 * np.random.randn(m,1)
    y = (4 + 3 * X + np.random.randn(m,1)).ravel()
    
    from sklearn.svm import LinearSVR
    
    svm_reg1 = LinearSVR(epsilon=1.5, random_state=42)
    svm_reg2 = LinearSVR(epsilon=0.5, random_state=42)
    
    svm_reg1.fit(X, y)
    svm_reg2.fit(X,y)
    

    可视化编码

    def find_support_vectors(svm_reg, X, y):
        y_pred = svm_reg.predict(X)
        off_margin = (np.abs(y - y_pred) >= svm_reg.epsilon)
        return np.argwhere(off_margin)
    
    svm_reg1.support_ = find_support_vectors(svm_reg1, X, y)
    svm_reg2.support_ = find_support_vectors(svm_reg2, X, y)
    
    eps_x1 = 1
    eps_y_pred = svm_reg1.predict([[eps_x1]])
    def plot_svm_regression(svm_reg, X, y, axes):
        x1s = np.linspace(axes[0], axes[1], 100).reshape(100, 1)
        y_pred = svm_reg.predict(x1s)
        plt.plot(x1s, y_pred, "k-", linewidth=2, label=r"$hat{y}$")
        plt.plot(x1s, y_pred + svm_reg.epsilon, "k--")
        plt.plot(x1s, y_pred - svm_reg.epsilon, "k--")
        plt.scatter(X[svm_reg.support_], y[svm_reg.support_], s=180, facecolors='#FFAAAA')
        plt.plot(X, y, "bo")
        plt.xlabel(r"$x_1$", fontsize=18)
        plt.legend(loc="upper left", fontsize=18)
        plt.axis(axes)
    
    plt.figure(figsize=(9, 4))
    plt.subplot(121)
    plot_svm_regression(svm_reg1, X, y, [0, 2, 3, 11])
    plt.title(r"$epsilon = {}$".format(svm_reg1.epsilon), fontsize=18)
    plt.ylabel(r"$y$", fontsize=18, rotation=0)
    #plt.plot([eps_x1, eps_x1], [eps_y_pred, eps_y_pred - svm_reg1.epsilon], "k-", linewidth=2)
    plt.annotate(
            '', xy=(eps_x1, eps_y_pred), xycoords='data',
            xytext=(eps_x1, eps_y_pred - svm_reg1.epsilon),
            textcoords='data', arrowprops={'arrowstyle': '<->', 'linewidth': 1.5}
        )
    plt.text(0.91, 5.6, r"$epsilon$", fontsize=20)
    plt.subplot(122)
    plot_svm_regression(svm_reg2, X, y, [0, 2, 3, 11])
    plt.title(r"$epsilon = {}$".format(svm_reg2.epsilon), fontsize=18)
    
    plt.show()
    

    可视化展示:

    非线性拟合

    造数据

    np.random.seed(42)
    m = 100
    X = 2 * np.random.rand(m, 1) - 1
    y = (0.2 + 0.1 * X + 0.5 * X**2 + np.random.randn(m, 1)/10).ravel()
    
    from sklearn.svm import SVR
    
    from sklearn.svm import SVR
    
    svm_poly_reg1 = SVR(kernel="poly", degree=2, C=100, epsilon=0.1, gamma="auto")
    svm_poly_reg2 = SVR(kernel="poly", degree=2, C=0.01, epsilon=0.1, gamma="auto")
    svm_poly_reg1.fit(X, y)
    svm_poly_reg2.fit(X, y)
    

    可视化编程

    plt.figure(figsize=(9, 4))
    plt.subplot(121)
    plot_svm_regression(svm_poly_reg1, X, y, [-1, 1, 0, 1])
    plt.title(r"$degree={}, C={}, epsilon = {}$".format(svm_poly_reg1.degree, svm_poly_reg1.C, svm_poly_reg1.epsilon), fontsize=18)
    plt.ylabel(r"$y$", fontsize=18, rotation=0)
    plt.subplot(122)
    plot_svm_regression(svm_poly_reg2, X, y, [-1, 1, 0, 1])
    plt.title(r"$degree={}, C={}, epsilon = {}$".format(svm_poly_reg2.degree, svm_poly_reg2.C, svm_poly_reg2.epsilon), fontsize=18)
    
    plt.show()
    

    可视化展示:

  • 相关阅读:
    kmp 算法
    jdk 和 cglib 的动态代理
    RestTemplate工具类
    bat脚本切换多个工程的分支
    字符串的左旋转
    输入一个正数s,打印出所有和为s的连续正数序列(至少含有两个数)。例如输入15,由于1+2+3+4+5=4+5+6=7+8=15,所以结果打印出3个连续序列1~5、4~6和7~8。
    枚举类型在JPA中的使用
    拾遗
    YAML DEMO
    kiali 1.26 anonymous策略修改为token
  • 原文地址:https://www.cnblogs.com/whiteBear/p/13096398.html
Copyright © 2011-2022 走看看