zoukankan      html  css  js  c++  java
  • Kaggle Digit Recognizer 练习

    本文代码主要是为了练习Kaggle流程,精确度不高。

    main.py

     1 #encoding:utf-8
     2 from functions import *
     3 from sklearn import neighbors
     4 
     5 #读取训练集数据
     6 trainData,trainLabel=readTrainData('train.csv')
     7 #读取测试集数据
     8 testData=readTestData('test.csv')
     9 
    10 #KNN算法中的k值,即最近邻的个数
    11 n_neighbors=5 
    12 #KNN算法的加权的两种方式
    13 weights=['uniform', 'distance'] 
    14 #定义一个knn算法
    15 clf = neighbors.KNeighborsClassifier(n_neighbors, weights=weights[1])
    16 #训练模型
    17 clf.fit(trainData,trainLabel) 
    18 #将模型应用于测试集
    19 testLabel=clf.predict(testData)
    20 #测试结果输出为Kaggle要求的格式
    21 getResult(testLabel)

    functions.py

     1 #encoding:utf-8
     2 import numpy as np
     3 import pandas as pd
     4 from pandas import Series,DataFrame
     5 
     6 #定义读取训练集数据的函数,该函数将训练数据集中的项和标签分开
     7 def readTrainData(fileName):
     8     #读取训练数据集
     9     df=pd.read_csv(fileName)
    10     
    11     #训练数据集的第一列为标签列,此时为Series数据结构
    12     trainLabel=df['label']     
    13     #将Series转换为ndarray数据结构,长度为42000,每一个ndarray项为一个0~9的数字
    14     trainLabel=trainLabel.values
    15     
    16     #删除训练数据集的标签列,余下的为训练数据集,42000*784   
    17     del df['label']
    18     #此时df仍为DataFrame数据结构,所以在此时对df进行简化处理
    19     #简化函数定义见后面代码
    20     df=simplify(df) 
    21     #将训练数据集从DataFrame数据结构转换为ndarray数据结构,
    22     #长度为42000,每一个ndarray项为一个长度为784的list
    23     trainData=df.values             
    24     return trainData,trainLabel
    25 
    26     
    27 def readTestData(fileName):
    28     #读取测试数据集
    29     df=pd.read_csv(fileName)
    30     #对训练集进行了简化处理,要对测试集进行同样的处理
    31     df=simplify(df) 
    32     #测试数据集由于没有标签列,
    33     #所以直接将测试数据集从DataFrame数据结构转换为ndarray数据结构
    34     testData=df.values
    35     return testData
    36 
    37 #为了减少模型训练时的计算量,可以对测试数据项和训练数据项进行简化处理
    38 #思想:由于是识别字迹,所以是什么颜色不是重点,重点是RGB值是0或不是0
    39 #所以,将非0的值都设置为1 
    40 #注意:传入的数据格式为DataFrame格式   
    41 def simplify(data):
    42     #DataFrame的applymap方法,对元素进行函数应用
    43     data=data.applymap(returnOneOrZero)
    44     return data
    45  
    46 #将非0的值替换为1,注意要return 0,否则结果0所在的位置都会变成NaN   
    47 def returnOneOrZero(num):
    48     if num:
    49         return 1
    50     else:
    51         return 0
    52     
    53     
    54 def getResult(testLabel):
    55     #通过scikit-learn得到的测试标签为ndarray格式,将其转换为Series数据结构
    56     testLabel=Series(testLabel)
    57     #加表头ImageId和Label。即第一列为ImageId列,第二列为Label,np.arange要加1
    58     df={'ImageId':np.arange(28000)+1,'Label':testLabel}
    59     #将df转换为DataFrame数据结构
    60     df=DataFrame(df)
    61     #讲结果输出到csv格式的文本中
    62     df.to_csv('result.csv',index=False) 

    analysis.py

     1 # -*- coding: utf-8 -*-
     2 import pandas as pd
     3 def getDiff(benchmark,result,test):
     4     list=[]
     5     for i in range(28000):
     6         if benchmark['Label'][i]==result['Label'][i]:
     7             list.append(i)
     8     benchmark=benchmark.drop(list)
     9     result=result.drop(list)
    10     test=test.drop(list)
    11     test['Label_right']=benchmark['Label']
    12     test['Label_wrong']=result['Label']
    13     return test
    14 
    16 benchmark=pd.read_csv('rf_benchmark.csv')
    17 result=pd.read_csv('result.csv')
    18 test=pd.read_csv('test.csv')
    19 errorAnalysis=getDiff(benchmark,result,test)
    20 print len(errorAnalysis)
    21 errorAnalysis.to_csv('errorAnalysis.csv',index=False)
  • 相关阅读:
    时间戳(1532249295.179) 转日期格式(2018/07/22 16:48:15 179)
    iscroll.js右侧可滑动的菜单,点击每个菜单都会出现本菜单的详情
    canvas绘制的文字如何换行
    移动端H5页面禁止长按复制和去掉点击时高亮
    一列宽度不缩放,一列宽度弹性缩放,且超出后显示省略号
    js钩子机制(hook)
    mCustomScrollbar.js 漂亮的滚动条插件 适应内容自动更新
    axios.js 实例 -----$.ajax的替代方案
    用 async/await 来处理异步实例
    C#入门经典第18章-WEB编程
  • 原文地址:https://www.cnblogs.com/PistonType/p/5390459.html
Copyright © 2011-2022 走看看