zoukankan      html  css  js  c++  java
  • cnn 卷积神经网络 人脸识别

      卷积网络博大精深,不同的网络模型,跑出来的结果是不一样,在不知道使用什么网络的情况下跑自己的数据集时,我建议最好去参考基于cnn的手写数字识别网络构建,在其基础上进行改进,对于一般测试数据集有很大的帮助。

    分享一个网络构架和一中训练方法:

    # coding:utf-8
    import os
    import tensorflow as tf
    
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    
    
    # cnn模型高度抽象特征
    def cnn_face_discern_model(X_,Y_):
        weights = {
            "wc1":tf.Variable(tf.random_normal([3,3,1,64],stddev=0.1)),
            "wc2":tf.Variable(tf.random_normal([5,5,64,128],stddev=0.1)),
            "wd3":tf.Variable(tf.random_normal([7*7*128,1024],stddev=0.1)),
            "wd4": tf.Variable(tf.random_normal([1024, 12], stddev=0.1))
        }
        biases = {
            "bc1":tf.Variable(tf.random_normal([64],stddev=0.1)),
            "bc2":tf.Variable(tf.random_normal([128],stddev=0.1)),
            "bd3": tf.Variable(tf.random_normal([1024],stddev=0.1)),
            "bd4": tf.Variable(tf.random_normal([12],stddev=0.1))
        }
        x_input =  tf.reshape(X_,shape=[-1,28,28,1])
    
        # 第一层卷积层
        _conv1 = tf.nn.conv2d(x_input,weights["wc1"],strides=[1,1,1,1],padding="SAME")
        _conv1_ = tf.nn.relu(tf.nn.bias_add(_conv1,biases["bc1"]))
        # 第一层池化层
        _pool1 = tf.nn.max_pool(_conv1_,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")
        # 第一层失活层
        _pool1_dropout = tf.nn.dropout(_pool1,0.7)
    
        # 第二层卷积层
        _conv2 = tf.nn.conv2d(_pool1_dropout,weights["wc2"],strides=[1,1,1,1],padding="SAME")
        _conv2_ = tf.nn.relu(tf.nn.bias_add(_conv2,biases["bc2"]))
        # 第二层池化层
        _pool2 = tf.nn.max_pool(_conv2_,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")
        # 第二层失活层
        _pool2_dropout =  tf.nn.dropout(_pool2,0.7)
    
        # 使用全连接层提取抽象特征
        # 全连接层1
        _densel =  tf.reshape(_pool2_dropout,[-1,weights["wd3"].get_shape().as_list()[0]])
        _y1 = tf.nn.relu(tf.add(tf.matmul(_densel,weights["wd3"]),biases["bd3"]))
        _y2 = tf.nn.dropout(_y1,0.7)
        # 全连接层2
        out = tf.add(tf.matmul(_y2,weights["wd4"]),biases["bd4"])
    
        # 损失函数 loss
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_, logits=out))  # 计算交叉熵
    
        # 优化目标 optimizing
        optimizing = tf.train.AdamOptimizer(0.001).minimize(loss)  # 使用adam优化器来以0.0001的学习率来进行微调
    
    
    
        # 精确度 accuracy
        correct_prediction = tf.equal(tf.argmax(Y_, 1), tf.argmax(out, 1))  # 判断预测标签和实际标签是否匹配
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    
    
    
        return {
            "loss":loss,
            "optimizing":optimizing,
            "accuracy":accuracy,
            "out":out
        }
    

      

    批量训练方法:

    # 开始准备训练cnn
    X = tf.placeholder(tf.float32,[None,28,28,1])
    # 这个12属于人脸类别,一共有几个id
    Y = tf.placeholder(tf.float32, [None,12])
    
    
    # 实例化模型
    cnn_model = cnn_face_discern_model(X,Y)
    
    loss,optimizing,accuracy,out = cnn_model["loss"],cnn_model["optimizing"],cnn_model["accuracy"],cnn_model["out"]
    
    
    # 启动训练模型
    bsize = 960/60
    
    with tf.Session() as sess:
        # 实例所有参数
        sess.run(tf.global_variables_initializer())
        for epoch in range(100):
            for i in range(15):
                x_bsize,y_bsize = x_train[i*60:i*60+60,:,:,:],y_train[i*60:i*60+60,:]
                sess.run(optimizing,feed_dict={X:x_bsize,Y:y_bsize})
    
            if (epoch+1)%10==0:
                los = sess.run(loss,feed_dict={X:x_test,Y:y_test})
                acc = sess.run(accuracy,feed_dict={X:x_test,Y:y_test})
    
                print("epoch:%s loss:%s accuracy:%s"%(epoch,los,acc))
    
        score= sess.run(accuracy,feed_dict={X:x_test,Y:y_test})
    
        y_pred = sess.run(out,feed_dict={X:x_test})
    
        # 这个是类别,测试集预测出来的类别。
        y_pred = np.argmax(y_pred,axis=1)
    
        print("最后的精确度为:%s"%score)
    

      

  • 相关阅读:
    Oracle学习(四)--sql及sql分类讲解
    Oracle学习(三)--数据类型及常用sql语句
    Oracle学习(二)--启动与关闭
    Tomcat学习笔记--启动成功访问报404错误
    有关Transaction not successfully started问题解决办法
    百度富文本编辑器UEditor1.3上传图片附件等
    hibernate+junit测试实体类生成数据库表
    js登录与注册验证
    SVN安装配置与使用
    [LeetCode] #38 Combination Sum
  • 原文地址:https://www.cnblogs.com/wuzaipei/p/10474177.html
Copyright © 2011-2022 走看看