zoukankan      html  css  js  c++  java
  • 学习进度笔记32

    观看 TensorFlow 案例实战视频课程 22：迭代及测试网络效果

    # Generate one training batch
    def get_next_batch(batch_size=128):
        """Build a batch of ``batch_size`` grayscale captcha samples.

        Returns:
            (batch_x, batch_y):
                batch_x -- array of shape (batch_size, IMAGE_HEIGHT*IMAGE_WIDTH),
                    each row a flattened grayscale image divided by 255.
                batch_y -- array of shape (batch_size, MAX_CAPTCHA*CHAR_SET_LEN),
                    each row the ``text2vec`` encoding of that image's text.
        """
        batch_x = np.zeros([batch_size, IMAGE_HEIGHT * IMAGE_WIDTH])
        batch_y = np.zeros([batch_size, MAX_CAPTCHA * CHAR_SET_LEN])

        # The generator sometimes yields images whose shape is not
        # (IMAGE_HEIGHT, IMAGE_WIDTH, 3); regenerate until the size matches.
        def wrap_gen_captcha_text_and_image():
            while True:
                text, image = gen_captcha_text_and_image()
                if image.shape == (IMAGE_HEIGHT, IMAGE_WIDTH, 3):
                    return text, image

        for i in range(batch_size):
            text, image = wrap_gen_captcha_text_and_image()
            image = convert2gray(image)

            # Scale pixels (presumably 0-255) into [0, 1];
            # a zero-mean alternative would be (image.flatten() - 128) / 128.
            batch_x[i, :] = image.flatten() / 255
            batch_y[i, :] = text2vec(text)

        # Return only after the loop so that every row of the batch is filled.
        return batch_x, batch_y
    
    # Training
    def train_crack_captcha_cnn():
        """Train the captcha CNN and save a checkpoint once it is accurate enough.

        Uses the module-level placeholders ``X``, ``Y`` and ``keep_prob`` and the
        network built by ``crack_captcha_cnn()``.

        Training runs on batches of 64 with a dropout keep probability of 0.75.
        Every 100 steps, per-character accuracy is measured on a fresh batch of
        100 samples. When that accuracy exceeds 0.85, the session is saved as
        ./model/crack_captcha.model-<step> and the function returns.

        Note: there is no step limit; the loop runs until the threshold is hit.
        """
        import os

        output = crack_captcha_cnn()

        # Y is MAX_CAPTCHA concatenated one-hot vectors, so apply softmax
        # cross-entropy per character position instead of across the whole
        # vector. TF requires the named logits=/labels= arguments here.
        char_logits = tf.reshape(output, [-1, CHAR_SET_LEN])
        char_labels = tf.reshape(Y, [-1, CHAR_SET_LEN])
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=char_logits, labels=char_labels))
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(loss)

        # Per-position argmax of predictions vs. labels -> per-character accuracy.
        predict = tf.reshape(output, [-1, MAX_CAPTCHA, CHAR_SET_LEN])
        max_idx_p = tf.argmax(predict, 2)
        max_idx_l = tf.argmax(tf.reshape(Y, [-1, MAX_CAPTCHA, CHAR_SET_LEN]), 2)
        correct_pred = tf.equal(max_idx_p, max_idx_l)
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            step = 0
            while True:
                batch_x, batch_y = get_next_batch(64)
                _, loss_ = sess.run([optimizer, loss],
                                    feed_dict={X: batch_x, Y: batch_y, keep_prob: 0.75})
                print(step, loss_)

                # Evaluate accuracy every 100 steps
                if step % 100 == 0:
                    batch_x_test, batch_y_test = get_next_batch(100)
                    acc = sess.run(accuracy,
                                   feed_dict={X: batch_x_test, Y: batch_y_test, keep_prob: 1.})
                    print(step, acc)
                    # Once accuracy exceeds 85%, save the model and finish training
                    if acc > 0.85:
                        # Saver.save refuses to write into a missing directory.
                        os.makedirs("./model", exist_ok=True)
                        saver.save(sess, "./model/crack_captcha.model", global_step=step)
                        break

                step += 1
    def crack_captcha(captcha_image, checkpoint=None):
        """Restore a trained checkpoint and predict the characters of one captcha.

        Args:
            captcha_image: one flattened image fed as a single row of ``X``;
                presumably preprocessed like get_next_batch (grayscale, / 255)
                -- confirm with the caller.
            checkpoint: checkpoint path to restore. When None, the most recent
                checkpoint recorded under ./model is used. Training saves with
                ``global_step=<step>``, so the suffix is not a fixed number.

        Returns:
            list[int]: the predicted index for each of the MAX_CAPTCHA positions.
            These are indices into the character set, not characters.

        Raises:
            FileNotFoundError: if no checkpoint is given and none exists in ./model.
        """
        output = crack_captcha_cnn()

        saver = tf.train.Saver()
        with tf.Session() as sess:
            if checkpoint is None:
                checkpoint = tf.train.latest_checkpoint("./model")
                if checkpoint is None:
                    raise FileNotFoundError("no checkpoint found in ./model")
            saver.restore(sess, checkpoint)

            predict = tf.argmax(tf.reshape(output, [-1, MAX_CAPTCHA, CHAR_SET_LEN]), 2)
            text_list = sess.run(predict, feed_dict={X: [captcha_image], keep_prob: 1})
            text = text_list[0].tolist()
            return text
    if __name__=='__main__':
        # train == 0: train a new model; train == 1: restore the saved model and
        # predict a freshly generated captcha.
        #train=0
        train = 1
        if train == 0:
            number = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
            #alphabet = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't','u', 'v', 'w', 'x', 'y', 'z']
            #ALPHABET = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T','U', 'V', 'W', 'X', 'Y', 'Z']

            text, image = gen_captcha_text_and_image()
            print("验证码图像channel:", image.shape)  # e.g. (60, 160, 3)
            # Image size
            IMAGE_HEIGHT = 60
            IMAGE_WIDTH = 160
            # Captcha length is taken from one sample
            MAX_CAPTCHA = len(text)
            print("验证码文本最长字符数", MAX_CAPTCHA)
            # Character set used for text <-> vector conversion
            #char_set=number+alphabet+ALPHABET+['_']  # '_' pads captchas shorter than 4 characters
            char_set = number
            CHAR_SET_LEN = len(char_set)

            X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH])
            Y = tf.placeholder(tf.float32, [None, MAX_CAPTCHA * CHAR_SET_LEN])
            keep_prob = tf.placeholder(tf.float32)  # dropout keep probability

            train_crack_captcha_cnn()
        elif train == 1:
            number = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
            IMAGE_HEIGHT = 60
            IMAGE_WIDTH = 160
            char_set = number
            CHAR_SET_LEN = len(char_set)

            text, image = gen_captcha_text_and_image()

            # Show the generated captcha with its true text
            f = plt.figure()
            ax = f.add_subplot(111)
            ax.text(0.1, 0.9, text, ha='center', va='center', transform=ax.transAxes)
            plt.imshow(image)

            plt.show()

            MAX_CAPTCHA = len(text)
            image = convert2gray(image)
            image = image.flatten() / 255

            X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH])
            Y = tf.placeholder(tf.float32, [None, MAX_CAPTCHA * CHAR_SET_LEN])
            keep_prob = tf.placeholder(tf.float32)  # dropout keep probability

            # crack_captcha returns character indices; map them back to text so
            # the prediction can be compared with the true text directly.
            predict_idx = crack_captcha(image)
            predict_text = "".join(char_set[i] for i in predict_idx)
            print("正确:{} 预测:{}".format(text, predict_text))
  • 相关阅读:
    Zk学习笔记——权限控制
    guava学习笔记
    Elasticsearch学习笔记——别名
    Kafka学习笔记——存储结构
    分布式协议——Paxos、Raft和ZAB
    图解 Java 中的数据结构及原理!
    牛逼哄哄的 Lambda 表达式,简洁优雅就是生产力!
    你必须了解Spring的生态
    盘点 35 个 Apache 顶级项目,我拜服了…
    前后端分离如何做权限控制设计?
  • 原文地址:https://www.cnblogs.com/zql-42/p/14632781.html
Copyright © 2011-2022 走看看