zoukankan      html  css  js  c++  java
  • 基于tensorflow的简单鼠标键盘识别


    import cv2 as cv
    import tensorflow as tf
    import numpy as np
    import random

    ##以下为数据预处理,分类为cata,总共样本为cata*num_batch,总共图像为cata*num_img
    cata=2 #需要分的类别
    num_img=49 #图像个数
    #该函数返回x与y,输入批量,产生cata*num_batch
    def XANDY(num_batch):

    x_mouse=np.zeros([num_batch,500,500,1]) #保存鼠标图片矩阵
    x_keyboard=np.zeros([num_batch,500,500,1]) #保存键盘图片矩阵
    temp_mouse=random.sample(range(0,num_img),num_batch)
    temp_keyboard=random.sample(range(0,num_img),num_batch)
    for i in range(num_batch):
    img_mouse1 = cv.imread('C:\Users\HHQ\Desktop\tangjun\mouse\data_mouse\'+str(temp_mouse[i])+'.PNG', cv.IMREAD_GRAYSCALE)
    img_mouse=cv.resize(img_mouse1,(500,500))
    x_mouse[i,:,:,0]=img_mouse
    img_keyboard1 = cv.imread('C:\Users\HHQ\Desktop\tangjun\mouse\data_keyboard\'+str(temp_keyboard [i])+'.bmp', cv.IMREAD_GRAYSCALE)
    img_keyboard = cv.resize(img_keyboard1, (500, 500))
    x_keyboard [i,:,:,0] = img_keyboard

    xx=np.vstack((x_mouse,x_keyboard))
    #表签中0表示鼠标,1表示键盘
    y_0=np.zeros([num_batch,1])
    y_1=np.ones([num_batch,1])
    y_mouse=np.hstack((y_1,y_0))
    y_keyboard=np.hstack((y_0,y_1))
    yy_=np.vstack((y_mouse,y_keyboard)) #标签为二维数组,行保存样本数量,列保存分类
    return xx,yy_





    x=tf.placeholder(dtype=tf.float32,shape=[None ,500,500,1])
    y_=tf.placeholder(dtype=tf.float32,shape=[None,cata])
    #建立卷积
    #第一层卷积
    W_cov1=tf.Variable(tf.truncated_normal([5,5,1,32],stddev=0.1),dtype=tf.float32)
    B_cov1=tf.Variable(tf.truncated_normal([32],stddev=0.1),dtype=tf.float32)
    A_cov1=tf.nn.relu(tf.nn.conv2d(x,W_cov1,strides=[1,1,1,1],padding='SAME')+B_cov1)
    P_cov1=tf.nn.max_pool(A_cov1,ksize=[1,2,2,1],strides=[1,2,2,1],padding='VALID')
    #得到250*250*32维度的图像

    #第二层卷积
    W_cov2=tf.Variable(tf.truncated_normal([5,5,32,64],stddev=0.1),dtype=tf.float32)
    B_cov2=tf.Variable(tf.truncated_normal([64],stddev=0.1),dtype=tf.float32)
    A_cov2=tf.nn.relu(tf.nn.conv2d(P_cov1,W_cov2,strides=[1,1,1,1],padding='SAME')+B_cov2)
    # #第三层卷积
    # W_cov3=tf.Variable(tf.truncated_normal())




    # 建立全连接层,识别2物体
    w=tf.Variable(tf.zeros([250*250*64,cata]),dtype= tf.float32)
    b=tf.Variable(tf.zeros([cata]),dtype=tf.float32)
    x_reshape=tf.reshape(A_cov2,[-1,250*250*64])
    y=tf.matmul(x_reshape,w)+b

    #定义交叉熵,为了定义损失函数
    loss=tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)
    # loss=-tf.reduce_mean(y_*tf.log(y))
    #定义优化器
    # train=tf.train.GradientDescentOptimizer(0.001).minimize(loss)
    # train=tf.train.AdagradDAOptimizer(0.01).minimize(loss)
    train=tf.train.AdamOptimizer(0.001).minimize(loss)
    #定义预测准确率
    predict1=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
    predict=tf.reduce_mean(tf.cast(predict1,tf.float32))

    init=tf.initialize_all_variables()
    sess=tf.Session()

    sess.run(init)
    x_pr,y_pr=XANDY(40)

    for i in range(30):
    x_ba,y_ba=XANDY(15)
    sess.run(train,feed_dict={x:x_ba,y_:y_ba})
    accuracy=sess.run(predict, feed_dict={x: x_pr, y_: y_pr})
    print('训练步骤: %d , 训练精度:%g' %(i,accuracy))
























  • 相关阅读:
    OLAP ODS项目的总结 平台选型,架构确定
    ORACLE ORA12520
    ORACLE管道函数
    ORACLE RAC JDBC 配置
    ORACLE RAC OCFS连接产生的错误
    ORACLE 启动和关闭详解
    OLAP ODS项目的总结 起步阶段
    ORACLE RAC 配置更改IP
    ORACLE RAC OCR cann't Access
    ORACLE RAC Debug 之路 CRS0184错误与CRS初始化
  • 原文地址:https://www.cnblogs.com/tangjunjun/p/10986239.html
Copyright © 2011-2022 走看看