zoukankan      html  css  js  c++  java
  • [MNIST数据集]输入图像的预处理

    因为MNIST数据是28*28的黑底白字图像,而且输入时要将其拉直,也就是可以看成1*784的二维张量(张量的值在0~1之间),所以我们要对图片进行预处理操作,是图片能被网络识别。

    以下是代码部分

    import tensorflow as tf
    import numpy as np
    from PIL import Image
    import backward as bw
    import forward as fw
    
    def restore(testPicArr):
        with tf.Graph().as_default() as g:
            x = tf.placeholder(tf.float32, [None, fw.INPUT_NODES])
            y_ = tf.placeholder(tf.float32, [None, fw.OUTPUT_NODES])
            y = fw.get_y(x, None)
            preValue = tf.arg_max(y, 1)
            
            ema = tf.train.ExponentialMovingAverage(bw.MOVING_ARVERAGE_DECAY)
            ema_restore = ema.variables_to_restore()
            saver = tf.train.Saver(ema_restore)
            
            with tf.Session() as sess:
                tf.logging.set_verbosity(tf.logging.WARN)#降低警告等级
                ckpt = tf.train.get_checkpoint_state("./model/")
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    
                    preValue = sess.run(preValue, feed_dict = {x: testPicArr})
                    return preValue
                else:
                    print("NO!!!")
                    return -1
        
    def pre_pic(picName):
        img = Image.open(picName)
        reIm = img.resize((28, 28), Image.ANTIALIAS)
        im_arr = np.array(reIm.convert('L'))#变为灰度图
        threshold = 50#阈值,将图片二值化操作
        for i in range(28):
            for j in range(28):
                im_arr[i][j] = 255 - im_arr[i][j]#进行反色处理
                if(im_arr[i][j] < threshold):
                    im_arr[i][j] = 0
                else: im_arr[i][j] = 255
        
        nm_arr = im_arr.reshape([1,784])
        nm_arr = nm_arr.astype(np.float32)#类型转换
        img_ready = np.multiply(nm_arr, 1.0/255.0)#把值变为0~1之间的数值
        
        return img_ready
    
    def app():
        testNum = input("Input the number of test pictutre:")
        for i in range(int(testNum)):
            testPic = input("the path of test picture:")
            testPicArr = pre_pic(testPic)
            preValue = restore(testPicArr)
            print("The prediction number is :" , preValue)
            
    def main():
        app()
        
    if __name__ == '__main__':
        main()
                
    记录点点滴滴
  • 相关阅读:
    正则表达式
    浏览器 User-Agent 大全
    python3爬虫开发实战 第六课 爬虫基本流程
    python3爬虫开发实战 第五课 常用库的安装
    python3爬虫开发实战 第四课 MySQL
    python3爬虫开发实战 第三课 Redis数据库
    python3爬虫开发实战 第二课 MongoDB安装
    python3爬虫开发实战 第一课 python安装和Pycharm安装
    批处理——数据库
    Aop所需包
  • 原文地址:https://www.cnblogs.com/1by1/p/10226900.html
Copyright © 2011-2022 走看看