zoukankan      html  css  js  c++  java
  • 使用LAP数据集进行年龄训练及估计

    一、背景

    原本是打算按《DEX Deep EXpectation of apparent age from a single image》进行表面年龄的训练,可由于IMDB-WIKI的数据集比较庞大,各个年龄段分布不均匀,难以划分训练集及验证集。后来为了先跑通整个训练过程的主要部分,就直接用LAP数据集,参考caffe的finetune_flickr_style,进行一些参数修改,利用bvlc_reference_caffenet.caffemodel完成年龄估计的finetune。

    二、训练数据集准备

    1、下载LAP数据集,包括Train、Validation、Test,以及对应的年龄label,http://chalearnlap.cvc.uab.es/dataset/18/description/,需要注册。也可以从我的网盘下载:

    链接:https://pan.baidu.com/s/1536TgbR_cCcS7-QHfEAeMw

    提取码:xc45

    2、将标注好的csv文件转换为caffe识别的txt格式。csv每一行的信息为:图片名、年龄、标准差。训练的时候不需要标准差信息,我们只要将图片名和年龄写入到txt中,并按空格隔开,得到Train.txt如下:

    image

    同样,完成验证集cvs文件的转换,得到Validation.txt。

    三、模型及相关文件拷贝

    1、拷贝预训练好的vgg16模型caffemodelsvlc_reference_caffenetvlc_reference_caffenet.caffemodel至工作目录下,该文件约232M;

    2、拷贝caffemodelsfinetune_flickr_style文件夹中deploy.prototxt、solver.prototxt、train_val.prototxt至工作目录下;

    3、拷贝imageNet的均值文件caffedatailsvrc12imagenet_mean.binaryproto至工作目录下。

    四、参数修改

    1、修改train_val.prototxt

    image

    以及最后的输出层个数,因为我们要训练的为[0,100]岁的输出,共101类,所以:

    image

    2、修改solver.protxt

    image

    3、修改用于实际测试的部署文件deploy.protxt

    image

    输出层的个数也要改:

    image

    五、开始训练

    1、新建train.bat

    caffe train -solver solver.prototxt -weights bvlc_reference_caffenet.caffemodel
    rem caffe train --solver solver.prototxt --snapshot snapshot/bvlc_iter_48000.solverstate
    pause

    双击即可开始训练,当训练过程中出现意外中断,可注释第一行,关闭第二行注释,根据实际情况修改保存,继续双击训练。

    我的电脑CPU是i5 6500,显卡为gtx1050Ti,8G内存,大致要训练10个小时吧,中途也出现了一些内存不足训练终止的情况。

    2、训练结束

    QQ截图20181005101028-lap2_2

    六、模型评价

    年龄估计原本是一个线性问题,不是一个明确的分类问题,人都无法准确无误地得到某人的年龄,更何况是机器呢。所以评价这个年龄分类模型的好坏不能简单地通过精度来衡量,可以用MAE(平均绝对误差)以及ε-error来衡量,其中

    image

    1、对验证集Validation.txt的所有图片进行预测

    借助 https://github.com/eveningglow/age-and-gender-classification ,其环境搭建可参考https://www.cnblogs.com/smbx-ztbz/p/9399016.html

    修改main函数

    int split(std::string str, std::string pattern, std::vector<std::string> &words)
    {
        std::string::size_type pos;
        std::string word;
        int num = 0;
        str += pattern;
        std::string::size_type size = str.size();
        for (auto i = 0; i < size; i++) {
            pos = str.find(pattern, i);
            if (pos == i) {
                continue;//if first string is pattern
            }
            if (pos < size) {
                word = str.substr(i, pos - i);
                words.push_back(word);
                i = pos + pattern.size() - 1;
                num++;
            }
        }
        return num;
    }
    
    //param example: model/deploy_age2.prototxt model/age_net.caffemodel model/mean.binaryproto img/0008.jpg
    int main(int argc, char* argv[])
    {
        if (argc != 5)
        {
            cout << "Command shoud be like ..." << endl;
            cout << "AgeAndGenderClassification ";
            cout << " "AGE_NET_MODEL_FILE_PATH" "AGE_NET_WEIGHT_FILE_PATH" "MEAN_FILE_PATH" "TEST_IMAGE" " << endl;
            std::cout << "argc = " << argc << std::endl;
            getchar();
            return 0;
        }
    
        // Get each file path
        string age_model(argv[1]);
        string age_weight(argv[2]);
        string mean_file(argv[3]);
        //string test_image(argv[4]);
    
        // Probability vector
        vector<Dtype> prob_age_vec;
    
        // Set mode
        Caffe::set_mode(Caffe::GPU);
    
        // Make AgeNet
        AgeNet age_net(age_model, age_weight, mean_file);
    
        // Initiailize both nets
        age_net.initNetwork();
    
        //读取待测试的图片名
        std::ifstream fin("E:\caffe\DEX_age_gender_predict\lap2\Validation.txt");
        std::string line;
        std::vector<std::string> test_images;
        std::vector<int> test_images_age;
        while (!fin.eof()) {
            std::getline(fin, line);
            std::vector<std::string> words;
            split(line, " ", words);
            test_images.push_back(words[0]);
            test_images_age.push_back(atoi(words[1].c_str()));
        }
        std::cout << "test_images size = " << test_images.size() << std::endl;
    
        std::ofstream fout("E:\caffe\DEX_age_gender_predict\lap2\Validation_predict.txt");
        for (int k = 0; k < test_images.size(); ++k) {
            std::cout << "k = " << k << std::endl;
            std::string test_image;
            test_image = test_images[k];
    
            // Classify and get probabilities
            Mat test_img = imread(test_image, CV_LOAD_IMAGE_COLOR);
            int age = age_net.classify(test_img, prob_age_vec);
    
            // Print result and show image
            //std::cout << "prob_age_vec size = " << prob_age_vec.size() << std::endl;
            //for (int i = 0; i < prob_age_vec.size(); ++i) {
            //    std::cout << "[" << i << "] = " << prob_age_vec[i] << std::endl;
            //}
    
            //Dtype prob;
            //int index;
            //get_max_value(prob_age_vec, prob, index);
            //std::cout << "prob = " << prob << ", index = " << index << std::endl;
    
            //imshow("AgeAndGender", test_img);
            //waitKey(0);
            fout << test_images[k] << " " << test_images_age[k] << " " << age << std::endl;
    
    
        }
    
        std::cout << "finish!" << std::endl;
        getchar();
        return 0;
    }

    我的命令参数为:E:caffeDEX_age_gender_predictlap2deploy.prototxt E:caffeDEX_age_gender_predictlap2snapshotvlc_iter_50000.caffemodel modelmean.binaryproto img008.jpg

    可根据实际情况修改。可得到Validation_predict.txt文件。运行过程中可能会因为内存不足中断运行,可能要分批次运行多次。

    2、计算MAE及ε-error

    (1)将Validation_predict.txt文件及验证集的标注文件Reference.csv拷贝到新建的vs项目的工作目录下;

    (2)计算

    #include <iostream>
    #include <string>
    #include <fstream>
    #include <vector>
    
    int split(std::string str, std::string pattern, std::vector<std::string> &words)
    {
        std::string::size_type pos;
        std::string word;
        int num = 0;
        str += pattern;
        std::string::size_type size = str.size();
        for (auto i = 0; i < size; i++) {
            pos = str.find(pattern, i);
            if (pos == i) {
                continue;//if first string is pattern
            }
            if (pos < size) {
                word = str.substr(i, pos - i);
                words.push_back(word);
                i = pos + pattern.size() - 1;
                num++;
            }
        }
        return num;
    }
    
    int main(int argc, char** argv)
    {
        //u, sigma, x
        std::vector<int> u;
        std::vector<float> sigma;
        std::vector<int> predict;
    
        std::string line;
        std::ifstream csv_file("Reference.csv");
        while (!csv_file.eof()) {
            std::getline(csv_file, line);
            std::vector<std::string> words;
            split(line, ";", words);
            u.push_back(atoi(words[1].c_str()));
            sigma.push_back(atof(words[2].c_str()));
        }
        std::ifstream predict_file("Validation_predict.txt");
        while (!predict_file.eof()) {
            std::getline(predict_file, line);
            std::vector<std::string> words;
            split(line, " ", words);
            predict.push_back(atoi(words[2].c_str()));
        }
        if (u.size() != predict.size()) {
            std::cout << "u.size() != predict.size()" << std::endl;
            getchar();
            return -1;
        }
    
        //MAE
        int sum_err = 0;
        float MAE = 0;
        for (int i = 0; i < u.size(); ++i) {
            sum_err += abs(u[i] - predict[i]);
        }
        MAE = static_cast<float>(sum_err) / u.size();
        std::cout << "MAE = " << MAE << std::endl;//11.7184
    
        //esro-error
        std::vector<float> errors;
        float err = 0;
        float error = 0.0;
        for (int i = 0; i < u.size(); ++i) {
            err = 1.0 - exp(-1.0*(predict[i] - u[i])*(predict[i] - u[i]) / (2 * sigma[i] * sigma[i]));
            errors.push_back(err);
            error += err;
        }
        error /= errors.size();
        std::cout << "error = " << error << std::endl;//0.682652
        
    
        std::cout << "finish!" << std::endl;
        getchar();
        return 0;
    }

    最终得到MAE为11.7184, ε-error为0.682652。

    七、实际应用中预测

    1、可利用caffe提供的classification工具对输入图片地进行估计

    classification deploy.prototxt snapshotvlc_iter_50000.caffemodel imagenet_mean.binaryproto ..age_labels.txt ..	est_image	est_3.jpg
    pause

    其中,age_labels.txt为0-100个label的说明信息,每个label对应一行,共101行,我的写法如下:

    image

    end

  • 相关阅读:
    SSH框架中使用Oracle数据库转换为SQLServer的相关配置和注意事项
    MYSQL性能优化系统整理
    PHP时间处理
    debian9 VirtualBox rc=-1908的错误
    https://snapcraft.io/store
    中文转拼音 pinyin4j的使用
    java对象转数组|数组转对象
    Deflater 压缩解压
    spring的RestTemplate连接池相关配置
    spring获取指定包下面的所有类
  • 原文地址:https://www.cnblogs.com/smbx-ztbz/p/9744970.html
Copyright © 2011-2022 走看看