zoukankan      html  css  js  c++  java
  • python数据分析——手写数字识别

    import numpy as np
    # bmp 图片后缀
    import matplotlib.pyplot as plt
    %matplotlib inline
    
    from sklearn.neighbors import KNeighborsClassifier
    
     
    • 提炼样本数据
    In [6]:
    img_arr = plt.imread('./data/3/3_100.bmp')
    plt.imshow(img_arr)
    
    Out[6]:
    <matplotlib.image.AxesImage at 0x2066baa2780>
     
    In [ ]:
    #./data/3/3_100.bmp
    
    In [59]:
    feature = []
    target = []
    for i in range(0,10):
        for j in range(1,501):
            img_path = './data/'+str(i)+'/'+str(i)+'_'+str(j)+'.bmp'
            img_arr = plt.imread(img_path)
            feature.append(img_arr)
            target.append(i)
    
    In [60]:
    len(feature)
    
    Out[60]:
    5000
     
    • 样本数据的提取
    In [61]:
    feature = np.array(feature)
    
    In [62]:
    feature.shape
    
    Out[62]:
    (5000, 28, 28)
    In [63]:
    # 特征数据必须保证是二维的
    # feature是一个三维数组(执行将维操作)
    feature = feature.reshape(5000,28*28)
    
    In [64]:
    feature.shape
    
    Out[64]:
    (5000, 784)
    In [65]:
    target = np.array(target)
    
     
    • 将样本打乱
    In [66]:
    np.random.seed(3)
    np.random.shuffle(feature)
    np.random.seed(3)
    np.random.shuffle(target)
    
     
    • 获取训练数据和测试数据
    In [67]:
    x_train = feature[:4950]
    y_train = target[:4950]
    x_test = feature[-50:]
    y_test = target[-50:]
    
     
    • 实例化模型对象,训练
    In [68]:
    knn = KNeighborsClassifier(n_neighbors=30)
    knn.fit(x_train,y_train)
    knn.score(x_train,y_train)
    
    Out[68]:
    0.9195959595959596
    In [69]:
    print('预测分类:',knn.predict(x_test))
    print('真实数据:',y_test)
    
     
    预测分类: [4 5 7 9 7 5 7 6 8 6 1 1 3 4 8 4 1 0 1 2 0 5 8 6 5 9 3 9 1 8 9 6 4 1 5 0 8
     7 7 1 5 3 5 5 6 1 1 3 6 3]
    真实数据: [4 5 7 9 7 5 7 6 8 6 4 1 3 4 8 4 2 0 1 2 0 5 8 6 5 9 3 9 1 8 9 6 4 1 5 2 8
     7 7 2 5 3 5 5 6 1 1 3 6 3]
    
     
    • 模型的保存
    In [82]:
    from sklearn.externals import joblib
    
    In [84]:
    joblib.dump(knn,'./digist.m')
    
    Out[84]:
    ['./digist.m']
     
    • 加载模型
    In [85]:
    knn = joblib.load('./digist.m')
    
     
    • 识别外部图片数字
    In [16]:
    #外部图片的识别
    img_arr = plt.imread('./数字.jpg')
    plt.imshow(img_arr)
    
    Out[16]:
    <matplotlib.image.AxesImage at 0x1d68f0e9d68>
     
    In [17]:
    five_arr = img_arr[90:155,80:135]
    
    In [18]:
    plt.imshow(five_arr)
    
    Out[18]:
    <matplotlib.image.AxesImage at 0x1d68f147438>
     
    In [19]:
    five_arr.shape
    
    Out[19]:
    (65, 55, 3)
    In [20]:
    #five数组是三维的,需要进行降维,舍弃第三个表示颜色的维度
    five_arr = five_arr.mean(axis=2)
    
    In [21]:
    five_arr.shape
    
    Out[21]:
    (65, 55)
     
    • 对图片进行等比压缩
    In [22]:
    import scipy.ndimage as ndimage
    
    In [23]:
    five = ndimage.zoom(five_arr,zoom = (28/65,28/55))
    
    In [24]:
    five.shape
    
    Out[24]:
    (28, 28)
    In [81]:
    # 转换为(1,784)形式
    knn.predict(five.reshape(1,784))
    
    Out[81]:
    array([5])
  • 相关阅读:
    小白详细解析C#反射特性实例
    几种快速排序算法实现
    Redis中算法之——Raft算法
    redis中算法之——MurmurHash2算法
    关于typedef的用法
    gdb调试工具常用命令
    gcc 常用命令
    Linux 远程登录ssh服务器
    Linux 构建ftp服务器
    知乎话题结构以及相关内容抓取二(Redis存储)
  • 原文地址:https://www.cnblogs.com/bilx/p/11647989.html
Copyright © 2011-2022 走看看