zoukankan      html  css  js  c++  java
  • mooc机器学习第九天-手写数字分类实例(MLP,K近邻)

    1.mooc数据任务简介

     

     

     2.代码

    (1)MLP

    import numpy as np     #导入numpy工具包
    from os import listdir #使用listdir模块,用于访问本地文件
    from sklearn.neural_network import MLPClassifier 
     
    def img2vector(fileName):    
        retMat = np.zeros([1024],int) #定义返回的矩阵,大小为1*1024
        fr = open(fileName)           #打开包含32*32大小的数字文件 
        lines = fr.readlines()        #读取文件的所有行
        for i in range(32):           #遍历文件所有行
            for j in range(32):       #并将01数字存放在retMat中     
                retMat[i*32+j] = lines[i][j]    
        return retMat
     
    def readDataSet(path):    
        fileList = listdir(path)    #获取文件夹下的所有文件 
        numFiles = len(fileList)    #统计需要读取的文件的数目
        dataSet = np.zeros([numFiles,1024],int) #用于存放所有的数字文件
        hwLabels = np.zeros([numFiles,10])      #用于存放对应的one-hot标签
        for i in range(numFiles):   #遍历所有的文件
            filePath = fileList[i]  #获取文件名称/路径      
            digit = int(filePath.split('_')[0])  #通过文件名获取标签      
            hwLabels[i][digit] = 1.0        #将对应的one-hot标签置1
            dataSet[i] = img2vector(path +'/'+filePath) #读取文件内容   
        return dataSet,hwLabels
     
    #read dataSet
    train_dataSet, train_hwLabels = readDataSet('trainingDigits')
     
    clf = MLPClassifier(hidden_layer_sizes=(100,),
                        activation='logistic', solver='adam',
                        learning_rate_init = 0.0001, max_iter=2000)
    print(clf)
    clf.fit(train_dataSet,train_hwLabels)
     
    #read  testing dataSet
    dataSet,hwLabels = readDataSet('testDigits')
    res = clf.predict(dataSet)   #对测试集进行预测
    error_num = 0                #统计预测错误的数目
    num = len(dataSet)           #测试集的数目
    for i in range(num):         #遍历预测结果
        #比较长度为10的数组,返回包含01的数组,0为不同,1为相同
        #若预测结果与真实结果相同,则10个数字全为1,否则不全为1
        if np.sum(res[i] == hwLabels[i]) < 10: 
            error_num += 1                     
    print("Total num:",num," Wrong num:", 
          error_num,"  WrongRate:",error_num / float(num))
    

    (2)K近邻

    import numpy as np     #导入numpy工具包
    from os import listdir #使用listdir模块,用于访问本地文件
    from sklearn import neighbors
     
    def img2vector(fileName):    
        retMat = np.zeros([1024],int) #定义返回的矩阵,大小为1*1024
        fr = open(fileName)           #打开包含32*32大小的数字文件 
        lines = fr.readlines()        #读取文件的所有行
        for i in range(32):           #遍历文件所有行
            for j in range(32):       #并将01数字存放在retMat中     
                retMat[i*32+j] = lines[i][j]    
        return retMat
     
    def readDataSet(path):    
        fileList = listdir(path)    #获取文件夹下的所有文件 
        numFiles = len(fileList)    #统计需要读取的文件的数目
        dataSet = np.zeros([numFiles,1024],int)    #用于存放所有的数字文件
        hwLabels = np.zeros([numFiles])#用于存放对应的标签(与神经网络的不同)
        for i in range(numFiles):      #遍历所有的文件
            filePath = fileList[i]     #获取文件名称/路径   
            digit = int(filePath.split('_')[0])   #通过文件名获取标签     
            hwLabels[i] = digit        #直接存放数字,并非one-hot向量
            dataSet[i] = img2vector(path +'/'+filePath)    #读取文件内容 
        return dataSet,hwLabels
     
    #read dataSet
    train_dataSet, train_hwLabels = readDataSet('trainingDigits')
    knn = neighbors.KNeighborsClassifier(algorithm='kd_tree', n_neighbors=3)
    knn.fit(train_dataSet, train_hwLabels)
     
    #read  testing dataSet
    dataSet,hwLabels = readDataSet('testDigits')
     
    res = knn.predict(dataSet)  #对测试集进行预测
    error_num = np.sum(res != hwLabels) #统计分类错误的数目
    num = len(dataSet)          #测试集的数目
    print("Total num:",num," Wrong num:", 
          error_num,"  WrongRate:",error_num / float(num))
    

    3.小结

    通过调整参数,可以看出,在较小(稀疏矩阵)的数据集时,K近邻的准确率更高,而全联接的学习率,

    拟合次数,神经元个数三者都会影响拟合效果,在这个数据集上,2000次已经是比较合适的,过多容易过拟合。

  • 相关阅读:
    poj 2411 Mondriaan's Dream 骨牌铺放 状压dp
    zoj 3471 Most Powerful (有向图)最大生成树 状压dp
    poj 2280 Islands and Bridges 哈密尔顿路 状压dp
    hdu 3001 Travelling 经过所有点(最多两次)的最短路径 三进制状压dp
    poj 3311 Hie with the Pie 经过所有点(可重)的最短路径 floyd + 状压dp
    poj 1185 炮兵阵地 状压dp
    poj 3254 Corn Fields 状压dp入门
    loj 6278 6279 数列分块入门 2 3
    VIM记事——大小写转换
    DKIM支持样本上传做检测的网站
  • 原文地址:https://www.cnblogs.com/cheflone/p/13332236.html
Copyright © 2011-2022 走看看