zoukankan      html  css  js  c++  java
  • 08-人脸识别-FaceNet-classify.py代码阅读(说明见注释)

    """An example of how to use your own dataset to train a classifier that recognizes people.
    """
    # MIT License
    # 
    # Copyright (c) 2016 David Sandberg
    # 
    # Permission is hereby granted, free of charge, to any person obtaining a copy
    # of this software and associated documentation files (the "Software"), to deal
    # in the Software without restriction, including without limitation the rights
    # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    # copies of the Software, and to permit persons to whom the Software is
    # furnished to do so, subject to the following conditions:
    # 
    # The above copyright notice and this permission notice shall be included in all
    # copies or substantial portions of the Software.
    # 
    # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    # SOFTWARE.
    
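    # @ Usage:
    # @
    # @ Train the classifier to remember faces (this does NOT train the embedding network; that network was trained beforehand).
    # @ ../lfw/ must be the LFW dataset after each face has been cropped out with MTCNN; uncropped images hurt accuracy (cropping removes the clutter around each face).
    # @     python classifier.py TRAIN ../lfw/ 20170511-185253/ train_20180419_2048.pkl
    # @
    # @ Check how well the trained classifier recognizes faces (../data is the directory containing the test images).
    # @     python classifier.py CLASSIFY ../data/ 20170511-185253/ train_20180419_2048.pkl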
    
    
    
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import tensorflow as tf
    import numpy as np
    import argparse
    import facenet
    import os
    import sys
    import math
    import pickle
    from sklearn.svm import SVC
    
    # @ See parse_arguments for the fields available on args.
    def main(args):
        """Train or evaluate an SVM face classifier on FaceNet embeddings.

        Part I (slow): compute an embedding for every image in the selected
        dataset with the pre-trained embedding network at ``args.model``.
        Part II: in TRAIN mode fit a linear SVC on the embeddings and pickle it
        to ``args.classifier_filename``; in CLASSIFY mode load that pickle and
        print the per-image predictions and the overall accuracy.

        Args:
            args: namespace returned by ``parse_arguments``.

        Raises:
            ValueError: in TRAIN mode, if some class in the dataset has no images.
        """
        with tf.Graph().as_default():
            with tf.Session() as sess:
                # Seed numpy's global RNG so the shuffle in split_dataset is reproducible.
                np.random.seed(seed=args.seed)

                # NOTE(review): --test_data_dir is parsed but never read here.
                dataset = _select_dataset(args)

                # Check that there is at least one training image per class.
                # This is an explicit check rather than `assert (cond, msg)`:
                # a non-empty tuple is always truthy, so that form never fires,
                # and asserts are stripped under `python -O` anyway.
                if args.mode == 'TRAIN':
                    for cls in dataset:
                        if len(cls.image_paths) == 0:
                            raise ValueError('There must be at least one image for each class in the dataset')

                # paths: image file paths; labels: the class label of each image.
                paths, labels = facenet.get_image_paths_and_labels(dataset)

                print('Number of classes: %d' % len(dataset))
                print('Number of images: %d' % len(paths))

                # Load the pre-trained embedding network (a .pb file or a
                # meta/ckpt directory, per the --model help text).
                print('Loading feature extraction model')
                facenet.load_model(args.model)

                emb_array = _compute_embeddings(sess, paths, args.batch_size, args.image_size)

                classifier_filename_exp = os.path.expanduser(args.classifier_filename)
                if args.mode == 'TRAIN':
                    _train_and_save(emb_array, labels, dataset, classifier_filename_exp)
                elif args.mode == 'CLASSIFY':
                    _classify_and_report(emb_array, labels, classifier_filename_exp)


    def _select_dataset(args):
        """Return the list of image classes to process for ``args.mode``.

        With ``--use_split_dataset`` the classes under ``args.data_dir`` are
        split by ``split_dataset``: TRAIN uses the training part and CLASSIFY
        the held-out part. Otherwise every class under ``args.data_dir`` is used.
        """
        if not args.use_split_dataset:
            return facenet.get_dataset(args.data_dir)
        train_set, test_set = split_dataset(facenet.get_dataset(args.data_dir),
                                            args.min_nrof_images_per_class,
                                            args.nrof_train_images_per_class)
        if args.mode == 'TRAIN':
            return train_set
        if args.mode == 'CLASSIFY':
            return test_set
        raise ValueError('Unknown mode: %r' % (args.mode,))


    def _compute_embeddings(sess, paths, batch_size, image_size):
        """Run the loaded graph over ``paths`` and return their embeddings.

        Images are fed in batches of ``batch_size``; the last batch may be
        smaller. Returns a float array with one row per path and
        ``embedding_size`` columns, where ``embedding_size`` is read from the
        shape of the graph's ``embeddings:0`` tensor.
        """
        graph = tf.get_default_graph()
        images_placeholder = graph.get_tensor_by_name("input:0")
        embeddings = graph.get_tensor_by_name("embeddings:0")
        phase_train_placeholder = graph.get_tensor_by_name("phase_train:0")
        embedding_size = embeddings.get_shape()[1]

        # Run forward pass to calculate embeddings
        print('Calculating features for images')
        nrof_images = len(paths)
        nrof_batches = int(math.ceil(1.0 * nrof_images / batch_size))
        emb_array = np.zeros((nrof_images, embedding_size))
        for i in range(nrof_batches):
            start_index = i * batch_size
            end_index = min(start_index + batch_size, nrof_images)
            # The two False flags presumably disable random crop / flip in
            # facenet.load_data -- confirm against facenet.
            images = facenet.load_data(paths[start_index:end_index], False, False, image_size)
            # phase_train=False presumably selects inference behaviour.
            feed_dict = {images_placeholder: images, phase_train_placeholder: False}
            emb_array[start_index:end_index, :] = sess.run(embeddings, feed_dict=feed_dict)
        return emb_array


    def _train_and_save(emb_array, labels, dataset, classifier_filename_exp):
        """Fit a linear SVC on the embeddings and pickle ``(model, class_names)``."""
        print('Training classifier')
        # SVC is the classification variant of SVM (SVR is the regression one).
        # probability=True enables predict_proba, which CLASSIFY mode uses.
        model = SVC(kernel='linear', probability=True)
        model.fit(emb_array, labels)

        # Human-readable class names, one per dataset entry; presumably index i
        # corresponds to label i from get_image_paths_and_labels.
        class_names = [cls.name.replace('_', ' ') for cls in dataset]

        with open(classifier_filename_exp, 'wb') as outfile:
            pickle.dump((model, class_names), outfile)
        print('Saved classifier model to file "%s"' % classifier_filename_exp)


    def _classify_and_report(emb_array, labels, classifier_filename_exp):
        """Load a pickled classifier, then print per-image predictions and accuracy."""
        print('Testing classifier')
        # SECURITY: pickle.load can execute arbitrary code; only load classifier
        # files from a trusted source.
        with open(classifier_filename_exp, 'rb') as infile:
            (model, class_names) = pickle.load(infile)
        print('Loaded classifier model from file "%s"' % classifier_filename_exp)

        # predict_proba yields one row of class probabilities per image, with
        # columns ordered as model.classes_. Map the best column back through
        # classes_ so the name lookup stays correct even if some label never
        # occurred in the training data.
        predictions = model.predict_proba(emb_array)
        best_columns = np.argmax(predictions, axis=1)  # best column of each row
        best_class_probabilities = predictions[np.arange(len(best_columns)), best_columns]
        best_class_indices = model.classes_[best_columns]

        for i, (class_index, probability) in enumerate(zip(best_class_indices, best_class_probabilities)):
            print('%4d  %s: %.3f' % (i, class_names[class_index], probability))

        # Fraction of images whose predicted label equals the label derived from
        # the classified directory. NOTE(review): only meaningful when that
        # directory lists its classes in the same order as the training data.
        accuracy = np.mean(np.equal(best_class_indices, labels))
        print('Accuracy: %.3f' % accuracy)
                    
    # @ Split a dataset into a training set and a test set.
    def split_dataset(dataset, min_nrof_images_per_class, nrof_train_images_per_class):
        """Split every sufficiently large class into a training and a test part.

        Classes with fewer than ``min_nrof_images_per_class`` images are
        dropped. For each remaining class a copy of its image paths is shuffled
        with numpy's global RNG. The first ``nrof_train_images_per_class``
        shuffled paths go to the training set and the rest to the test set.
        The input dataset is not modified.

        Returns:
            (train_set, test_set): two lists of ``facenet.ImageClass`` that
            contain the same class names in the same order.
        """
        train_set = []
        test_set = []
        for cls in dataset:
            # Remove classes with less than min_nrof_images_per_class
            if len(cls.image_paths) < min_nrof_images_per_class:
                continue
            # Shuffle a copy: shuffling cls.image_paths in place would silently
            # reorder the caller's dataset.
            paths = list(cls.image_paths)
            np.random.shuffle(paths)
            train_set.append(facenet.ImageClass(cls.name, paths[:nrof_train_images_per_class]))
            test_set.append(facenet.ImageClass(cls.name, paths[nrof_train_images_per_class:]))
        return train_set, test_set
    	
    # @ Command-line parsing, built on the standard-library argparse module.
    def parse_arguments(argv):
        """Parse the script's command line.

        Args:
            argv: argument strings, without the program name.

        Returns:
            argparse.Namespace with mode, data_dir, model, classifier_filename,
            use_split_dataset, test_data_dir and the integer tuning options.
        """
        parser = argparse.ArgumentParser()
        add = parser.add_argument

        # Positional arguments.
        add('mode', type=str, choices=['TRAIN', 'CLASSIFY'], default='CLASSIFY',
            help=('Indicates if a new classifier should be trained or a classification '
                  'model should be used for classification'))
        add('data_dir', type=str,
            help='Path to the data directory containing aligned LFW face patches.')
        add('model', type=str,
            help='Could be either a directory containing the meta_file and ckpt_file or a model protobuf (.pb) file')
        add('classifier_filename',
            help=('Classifier model file name as a pickle (.pkl) file. '
                  'For training this is the output and for classification this is an input.'))

        # Optional flags.
        add('--use_split_dataset', action='store_true',
            help=('Indicates that the dataset specified by data_dir should be split into a training and test set. '
                  'Otherwise a separate test set can be specified using the test_data_dir option.'))
        add('--test_data_dir', type=str,
            help='Path to the test data directory containing aligned images used for testing.')

        # Integer options: (flag, default, help text), registered in this order.
        integer_options = (
            ('--batch_size', 90, 'Number of images to process in a batch.'),
            ('--image_size', 160, 'Image size (height, width) in pixels.'),
            ('--seed', 666, 'Random seed.'),
            ('--min_nrof_images_per_class', 20,
             'Only include classes with at least this number of images in the dataset'),
            ('--nrof_train_images_per_class', 10,
             'Use this number of images from each class for training and the rest for testing'),
        )
        for flag, default_value, help_text in integer_options:
            add(flag, type=int, help=help_text, default=default_value)

        return parser.parse_args(argv)
    
    # @ Script entry point: sys.argv[1:] holds every whitespace-separated token
    # @ typed after the script name on the command line.
    if __name__ == '__main__':
        cli_args = parse_arguments(sys.argv[1:])
        main(cli_args)
    

      

  • 相关阅读:
    php5.3连接sqlserver2005
    U盘文件名称变成乱码的解决方法
    sql小计汇总 rollup用法实例分析(转)
    关于document.all.item遇到IE8时无法正常取到数据
    jQuery 库中的 $() 是什么?
    JavaScript内置可用类型
    jquery中$.get()提交和$.post()提交有区别吗?
    什么是CDN?哪些是流行的jQuery CDN?使用CDN有什么好处?
    说一说Servlet的生命周期?
    request.getAttribute()和 request.getParameter()有何区别?
  • 原文地址:https://www.cnblogs.com/alexYuin/p/8886727.html
Copyright © 2011-2022 走看看