  • [Caffe Code Walkthrough] compute_image_mean

    Purpose:
    Compute the mean image of a training database.


    Normalizing the training images by their mean tends to improve results, so Caffe provides this tool as an optional step.

    Usage:
    compute_image_mean [FLAGS] INPUT_DB [OUTPUT_FILE]
    Argument: INPUT_DB: the input database
    Argument (optional): OUTPUT_FILE: the output file name; if omitted, the mean image blob is not saved

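    Implementation:

    Data source: the mean image is computed from records read directly from the database (LevelDB or LMDB), not from the original image files. This means the images must be converted into a database first; only then can the mean image be computed.

    The tool then iterates over every value in the key-value database with while (cursor->valid()), parsing each value into a Datum with datum.ParseFromString(cursor->value());

    Each Datum is then decoded and accumulated into sum_blob. sum_blob is a blob with num=1, channels=image channels, height=image height, width=image width.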

    Accumulation:

    sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]);

    Finally, the average:

    sum_blob.set_data(i, sum_blob.data(i) / count);
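The accumulate-then-divide pattern can be tried without a database. The following is a minimal sketch, not Caffe code: MeanImage is a hypothetical helper, and each image is assumed to be a raw byte string of equal length, like the data field of a Datum.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper (not part of Caffe): averages equally sized images
// stored as raw bytes, mirroring the accumulate-then-divide loop.
std::vector<float> MeanImage(const std::vector<std::string>& images) {
  if (images.empty()) {
    throw std::invalid_argument("no images given");
  }
  const std::size_t size = images[0].size();
  std::vector<float> sum(size, 0.0f);
  for (const std::string& data : images) {
    if (data.size() != size) {
      throw std::invalid_argument("image size mismatch");
    }
    for (std::size_t i = 0; i < size; ++i) {
      // char may be signed, so cast to uint8_t first; otherwise a pixel
      // value of 200 could be read as -56.
      sum[i] += static_cast<uint8_t>(data[i]);
    }
  }
  // Divide each accumulated pixel by the number of images.
  for (std::size_t i = 0; i < size; ++i) {
    sum[i] /= static_cast<float>(images.size());
  }
  return sum;
}
```

The uint8_t cast is the same one the original accumulation line uses, and it matters because std::string stores char, whose signedness depends on the platform.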

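    Known issue: the code simply accumulates all values first and divides by the count at the end. The sums are stored as float, so with a very large number of images the result can degrade. In practice the concern is loss of precision rather than true floating-point overflow: once a sum reaches 2^24 (after roughly 65,800 images for a pixel that is always 255), not every integer is representable and small additions get rounded.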
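How a float accumulator behaves with many additions can be checked in isolation. This is a sketch under assumptions: SumInFloat and SumInDouble are hypothetical helpers, and switching to a double accumulator is one possible mitigation, not something the Caffe tool does.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper: adds `value` to a float accumulator `count` times,
// the way a single sum_blob entry is built up.
float SumInFloat(uint8_t value, std::size_t count) {
  float sum = 0.0f;
  for (std::size_t i = 0; i < count; ++i) {
    sum += static_cast<float>(value);
  }
  return sum;
}

// Hypothetical alternative: the same loop with a double accumulator,
// which represents integers exactly up to 2^53.
double SumInDouble(uint8_t value, std::size_t count) {
  double sum = 0.0;
  for (std::size_t i = 0; i < count; ++i) {
    sum += static_cast<double>(value);
  }
  return sum;
}
```

On typical IEEE-754 hardware, SumInFloat(1, 20000000) stalls at 16777216 (2^24), because adding 1 to that value rounds back to it, while SumInDouble(1, 20000000) reaches 20000000. The sum never becomes infinite; it just stops being exact.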

    Finally, if something simpler is sufficient, you can just compute the mean value of each channel.
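Per-channel means can be derived from the mean image with one more pass. A minimal sketch, assuming the mean image is stored channel by channel (all pixels of channel 0, then channel 1, and so on); ChannelMeans is a hypothetical name, not a Caffe function.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper: reduces a channel-major mean image to one mean
// value per channel.
std::vector<float> ChannelMeans(const std::vector<float>& mean_image,
                                int channels, int height, int width) {
  const int dim = height * width;  // pixels per channel
  if (channels <= 0 || dim <= 0 ||
      mean_image.size() != static_cast<std::size_t>(channels) * dim) {
    throw std::invalid_argument("shape does not match data size");
  }
  std::vector<float> means(channels, 0.0f);
  for (int c = 0; c < channels; ++c) {
    for (int i = 0; i < dim; ++i) {
      // Channel c occupies the index range [dim * c, dim * (c + 1)).
      means[c] += mean_image[dim * c + i];
    }
    means[c] /= dim;
  }
  return means;
}
```

For a 2-channel, 1x2 mean image {1, 3, 10, 20}, this yields 2 for the first channel and 15 for the second. The indexing dim * c + i is the same one used at the end of the tool's main function.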
    Source code (version of 2015-06-04):

    #include <stdint.h>
    #include <algorithm>
    #include <string>
    #include <utility>
    #include <vector>
    
    #include "boost/scoped_ptr.hpp"
    #include "gflags/gflags.h"
    #include "glog/logging.h"
    
    #include "caffe/proto/caffe.pb.h"
    #include "caffe/util/db.hpp"
    #include "caffe/util/io.hpp"
    
    using namespace caffe;  // NOLINT(build/namespaces)
    
    using std::max;
    using std::pair;
    using boost::scoped_ptr;
    
    DEFINE_string(backend, "lmdb",
            "The backend {leveldb, lmdb} containing the images");
    
    int main(int argc, char** argv) {
      ::google::InitGoogleLogging(argv[0]);
    
    #ifndef GFLAGS_GFLAGS_H_
      namespace gflags = google;
    #endif
    
      gflags::SetUsageMessage("Compute the mean_image of a set of images given by"
            " a leveldb/lmdb\n"
            "Usage:\n"
            "    compute_image_mean [FLAGS] INPUT_DB [OUTPUT_FILE]\n");
    
      gflags::ParseCommandLineFlags(&argc, &argv, true);
    
      if (argc < 2 || argc > 3) {
        gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/compute_image_mean");
        return 1;
      }
    
      scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend));
      db->Open(argv[1], db::READ);
      scoped_ptr<db::Cursor> cursor(db->NewCursor());
    
      BlobProto sum_blob;
      int count = 0;
      // load first datum
      Datum datum;
      datum.ParseFromString(cursor->value());
    
      if (DecodeDatumNative(&datum)) {
        LOG(INFO) << "Decoding Datum";
      }
    
      sum_blob.set_num(1);
      sum_blob.set_channels(datum.channels());
      sum_blob.set_height(datum.height());
      sum_blob.set_width(datum.width());
      const int data_size = datum.channels() * datum.height() * datum.width();
      int size_in_datum = std::max<int>(datum.data().size(),
                                        datum.float_data_size());
      for (int i = 0; i < size_in_datum; ++i) {
        sum_blob.add_data(0.);
      }
      LOG(INFO) << "Starting Iteration";
      while (cursor->valid()) {
        Datum datum;
        datum.ParseFromString(cursor->value());
        DecodeDatumNative(&datum);
    
        const std::string& data = datum.data();
        size_in_datum = std::max<int>(datum.data().size(),
            datum.float_data_size());
        CHECK_EQ(size_in_datum, data_size) << "Incorrect data field size " <<
            size_in_datum;
        if (data.size() != 0) {
          CHECK_EQ(data.size(), size_in_datum);
          for (int i = 0; i < size_in_datum; ++i) {
            sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]);
          }
        } else {
          CHECK_EQ(datum.float_data_size(), size_in_datum);
          for (int i = 0; i < size_in_datum; ++i) {
            sum_blob.set_data(i, sum_blob.data(i) +
                static_cast<float>(datum.float_data(i)));
          }
        }
        ++count;
        if (count % 10000 == 0) {
          LOG(INFO) << "Processed " << count << " files.";
        }
        cursor->Next();
      }
    
      if (count % 10000 != 0) {
        LOG(INFO) << "Processed " << count << " files.";
      }
      for (int i = 0; i < sum_blob.data_size(); ++i) {
        sum_blob.set_data(i, sum_blob.data(i) / count);
      }
      // Write to disk
      if (argc == 3) {
        LOG(INFO) << "Write to " << argv[2];
        WriteProtoToBinaryFile(sum_blob, argv[2]);
      }
      const int channels = sum_blob.channels();
      const int dim = sum_blob.height() * sum_blob.width();
      std::vector<float> mean_values(channels, 0.0);
      LOG(INFO) << "Number of channels: " << channels;
      for (int c = 0; c < channels; ++c) {
        for (int i = 0; i < dim; ++i) {
          mean_values[c] += sum_blob.data(dim * c + i);
        }
        LOG(INFO) << "mean_value channel [" << c << "]:" << mean_values[c] / dim;
      }
      return 0;
    }
    
  • Original article: https://www.cnblogs.com/yfceshi/p/7225701.html