观看TensorFlow案例实战视频课程22：迭代及测试网络效果
# Build one training batch.
def get_next_batch(batch_size=128):
    """Generate a batch of random captcha samples.

    Args:
        batch_size: number of samples in the batch.

    Returns:
        (batch_x, batch_y): ``batch_x`` has shape
        ``(batch_size, IMAGE_HEIGHT * IMAGE_WIDTH)`` and holds the flattened
        grayscale images divided by 255; ``batch_y`` has shape
        ``(batch_size, MAX_CAPTCHA * CHAR_SET_LEN)`` and holds the label
        vectors produced by ``text2vec``.
    """
    batch_x = np.zeros([batch_size, IMAGE_HEIGHT * IMAGE_WIDTH])
    batch_y = np.zeros([batch_size, MAX_CAPTCHA * CHAR_SET_LEN])

    # The generator sometimes returns an image that is not
    # (IMAGE_HEIGHT, IMAGE_WIDTH, 3); keep drawing until the shape matches.
    def wrap_gen_captcha_text_and_image():
        while True:
            text, image = gen_captcha_text_and_image()
            if image.shape == (IMAGE_HEIGHT, IMAGE_WIDTH, 3):
                return text, image

    for i in range(batch_size):
        text, image = wrap_gen_captcha_text_and_image()
        image = convert2gray(image)
        # Presumably 0-255 pixel values, so this scales to about [0, 1].
        # Alternative: (image.flatten() - 128) / 128 for a zero-mean input.
        batch_x[i, :] = image.flatten() / 255
        batch_y[i, :] = text2vec(text)
    return batch_x, batch_y


# Training
def train_crack_captcha_cnn():
    """Train the captcha CNN until test accuracy exceeds 0.85, then save it.

    Uses the module-level placeholders ``X``, ``Y`` and ``keep_prob`` and
    builds the network with ``crack_captcha_cnn()``. Every 100 steps the
    per-character accuracy is measured on a fresh batch of 100 samples with
    dropout disabled. Once it exceeds 0.85, a checkpoint is written to
    ``./model/crack_captcha.model-<step>`` and training stops.
    """
    import os

    output = crack_captcha_cnn()
    # The accuracy computation treats Y as MAX_CAPTCHA per-character vectors of
    # length CHAR_SET_LEN (presumably one-hot, from text2vec). Apply the softmax
    # per character rather than over the whole concatenated vector, whose
    # labels would not form a single probability distribution.
    predict = tf.reshape(output, [-1, MAX_CAPTCHA, CHAR_SET_LEN])
    labels = tf.reshape(Y, [-1, MAX_CAPTCHA, CHAR_SET_LEN])
    # TF >= 1.0 only accepts named arguments for this op.
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            logits=tf.reshape(predict, [-1, CHAR_SET_LEN]),
            labels=tf.reshape(labels, [-1, CHAR_SET_LEN])))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(loss)

    # Per-character accuracy: compare the arg-max class of every position.
    max_idx_p = tf.argmax(predict, 2)
    max_idx_l = tf.argmax(labels, 2)
    correct_pred = tf.equal(max_idx_p, max_idx_l)
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    saver = tf.train.Saver()
    # Saver.save does not create missing parent directories.
    os.makedirs("./model", exist_ok=True)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        step = 0
        while True:
            batch_x, batch_y = get_next_batch(64)
            _, loss_ = sess.run([optimizer, loss],
                                feed_dict={X: batch_x, Y: batch_y, keep_prob: 0.75})
            print(step, loss_)
            # Compute accuracy every 100 steps.
            if step % 100 == 0:
                batch_x_test, batch_y_test = get_next_batch(100)
                acc = sess.run(accuracy,
                               feed_dict={X: batch_x_test, Y: batch_y_test, keep_prob: 1.})
                print(step, acc)
                # If accuracy exceeds 85%, save the model and finish training.
                if acc > 0.85:
                    saver.save(sess, "./model/crack_captcha.model", global_step=step)
                    break
            step += 1


def crack_captcha(captcha_image, checkpoint=None):
    """Predict the character indices of one preprocessed captcha image.

    Args:
        captcha_image: flattened image, preprocessed the same way as the
            training data (``convert2gray`` then ``flatten() / 255``).
        checkpoint: checkpoint path to restore. Defaults to the latest
            checkpoint recorded in ``./model``, the directory that
            ``train_crack_captcha_cnn`` saves to.

    Returns:
        list of predicted indices into ``char_set``, one per character.

    Raises:
        ValueError: if ``checkpoint`` is None and no checkpoint exists in ./model.
    """
    output = crack_captcha_cnn()
    saver = tf.train.Saver()
    if checkpoint is None:
        # Training saves at whatever step first crosses the accuracy
        # threshold, so a hard-coded step suffix is not reliable.
        checkpoint = tf.train.latest_checkpoint("./model")
        if checkpoint is None:
            raise ValueError("no checkpoint found in ./model; train the model first")
    with tf.Session() as sess:
        saver.restore(sess, checkpoint)
        predict = tf.argmax(tf.reshape(output, [-1, MAX_CAPTCHA, CHAR_SET_LEN]), 2)
        text_list = sess.run(predict, feed_dict={X: [captcha_image], keep_prob: 1})
        text = text_list[0].tolist()
        return text


if __name__ == '__main__':
    # Mode switch, as the branches below are written:
    # 0 -> train the network, 1 -> restore a checkpoint and predict one captcha.
    # train = 0
    train = 1
    if train == 0:
        number = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
        # alphabet = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
        #             'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
        # ALPHABET = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
        #             'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
        text, image = gen_captcha_text_and_image()
        print("验证码图像channel:", image.shape)  # (60, 160, 3)
        # Image size
        IMAGE_HEIGHT = 60
        IMAGE_WIDTH = 160
        MAX_CAPTCHA = len(text)
        print("验证码文本最长字符数", MAX_CAPTCHA)
        # Character set used for text <-> vector conversion.
        # char_set = number + alphabet + ALPHABET + ['_']  # '_' pads captchas shorter than 4
        char_set = number
        CHAR_SET_LEN = len(char_set)
        X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH])
        Y = tf.placeholder(tf.float32, [None, MAX_CAPTCHA * CHAR_SET_LEN])
        keep_prob = tf.placeholder(tf.float32)  # dropout
        train_crack_captcha_cnn()
    elif train == 1:
        number = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
        IMAGE_HEIGHT = 60
        IMAGE_WIDTH = 160
        char_set = number
        CHAR_SET_LEN = len(char_set)

        text, image = gen_captcha_text_and_image()

        # Show the generated captcha with its true text.
        f = plt.figure()
        ax = f.add_subplot(111)
        ax.text(0.1, 0.9, text, ha='center', va='center', transform=ax.transAxes)
        plt.imshow(image)
        plt.show()

        MAX_CAPTCHA = len(text)
        # Same preprocessing as get_next_batch.
        image = convert2gray(image)
        image = image.flatten() / 255

        X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH])
        Y = tf.placeholder(tf.float32, [None, MAX_CAPTCHA * CHAR_SET_LEN])
        keep_prob = tf.placeholder(tf.float32)  # dropout

        # crack_captcha returns indices into char_set; map them back to text so
        # the prediction can be compared directly with the true text.
        predict_text = "".join(char_set[i] for i in crack_captcha(image))
        print("正确:{} 预测:{}".format(text, predict_text))