zoukankan      html  css  js  c++  java
  • 支持向量机实现

    采用的测试数据:参考上一篇博客4.1部分

    https://www.cnblogs.com/hhjing/p/14340924.html

    1、

    import numpy as np
    import matplotlib.pyplot as plt
    %matplotlib inline
    
    #定义函数
    def linear_svm(X,y,lam,max_iter=2000):
        w=np.zeros(X.shape[1])#初始化
        support_vectors=[]#创建空列表保存支持向量
        
        for t in range(max_iter):#进行died
            learning_rate=1/(lam*(t+1))#计算本轮迭代的学习率
            i=np.random.randint(len(X))#从训练集中随机抽取一个样本
            ywx=w.T.dot(X.values[1])*y[i]#计算y_i w^T x_i
            
            if ywx<1:#进行指示函数的判断
                w=w-learning_rate*lam*w+learning_rate*y[i]*X.values[i]#更新参数
            else:
                w=w-learning_rate*lam*w
            
            for i in range(len(X)):
                ywx=w.T.dot(X.values[i])*y[i]
                if ywx<=1:#根据样本是否位于间隔附近判断是否为支持向量
                    support_vectors.append(X.values[i])
                    
            return w,support_vectors

    2、线性支持向量机的正则化项通常不包括截距项,可以将数据进行中心化,再调用上述代码

    #对训练集数据进行归一化,则模型无需再计算截距项
    X=data[["x1","x2"]].apply(lambda x:x-x.mean())
    #训练集标签
    y=data["label"]
    w,support_vectors=linear_svm(X,y,lam=0.05,max_iter=5000)

    3、将得到的超平面可视化,同时将两个函数间隔为1的线也绘制出来,对于所有不满足约束条件的样本,使用圆圈标记出来

    plt.figure(figsize=(8,8))#设置图片尺寸
    
    #绘制两类样本点
    X_pos=X[y==1]
    X_neg=X[y==-1]
    plt.scatter(data_pos["x1"],data_pos["x2"],c="#E4007F",marker="^")#类别为1的数据绘制成洋红色
    plt.scatter(data_neg["x1"],data_neg["x2"],c="#007979",marker="o")#类别为-1的数据绘制成深绿色
    
    #绘制超平面
    x1=np.linspace(-6,6,50)
    x2=-w[0]*x1/w[1]
    plt.plot(x1,x2,c="gray")
    
    #绘制两个间隔超平面
    plt.plot(x1,-(w[0]*x1+1)/w[1],"--",c="#007979")
    plt.plot(x1,-(w[0]*x1-1)/w[1],"--",c="#E4007E")
    
    #标注支持向量
    for x in support_vectors:
        plt.plot(x[0],x[1],"ro",linewidth=2,markersize=12,markerfacecolor='none')
        
    #添加轴标签,限制轴范围
    plt.xlabel("$x_1$")#设置横轴标签
    plt.ylabel("$x_2$")#设置纵轴标签
    
    plt.xlim(-6,6)#设置横轴显示范围
    plt.ylim(-2,2)#设置纵轴显示范围
  • 相关阅读:
    打印杨辉三角
    插值排序
    各种冒泡排序法
    Linux系统命令符01
    2.1博客系统 |基于form组件和Ajax实现注册登录
    python面试笔试题,你都会了吗?快来复习
    1.2博客系统 |登录页| 验证码
    1.1博客系统| 表结构
    第五章:5.2面向对象-绑定方法和非绑定方法| 内置方法 |元类
    11.Django|中间件
  • 原文地址:https://www.cnblogs.com/hhjing/p/14342003.html
Copyright © 2011-2022 走看看