zoukankan      html  css  js  c++  java
  • caffe神经网络框架的辅助工具(将图片转换为leveldb格式)


    caffe中负责整个网络输入的datalayer是从leveldb里读取数据的,是一个google实现的很高效的kv数据库。

    因此我们训练网络必须先把数据转成leveldb的格式。

    这里我实现的是把一个目录的全部图片转成leveldb的格式。


    工具使用命令格格式:convert_imagedata src_dir dst_dir attach_dir channel width height

    例子:./convert_imagedata.bin /home/linger/imdata/collar_train/ /home/linger/linger/testfile/crop_train_db/ /home/linger/linger/testfile/crop_train_attachment/ 3 50 50


    源码:

    #include <google/protobuf/text_format.h>
    #include <glog/logging.h>
    #include <leveldb/db.h>
    
    #include <stdint.h>
    #include <fstream>  // NOLINT(readability/streams)
    #include <string>
    #include <set>
    #include <stdio.h>
    #include <string.h>
    #include <stdlib.h>
    #include <dirent.h>
    #include <sys/stat.h>
    #include <unistd.h>
    #include <sys/types.h>
    #include "caffe/proto/caffe.pb.h"
    #include <opencv2/highgui/highgui.hpp>
    #include <opencv2/highgui/highgui_c.h>
    #include <opencv2/imgproc/imgproc.hpp>
    
    using std::string;
    using namespace std;
    
    
    set<string> all_class_name;
    map<string,int> class2id;
    
    
    /**
     * path:文件夹
     * files:用于保存文件名称的vector
     * r:是否须要遍历子文件夹
     * return:文件名称,不包括路径
     */
    void list_dir(const char *path,vector<string>& files,bool r = false)
    {
    	DIR *pDir;
    	struct dirent *ent;
    	char childpath[512];
    	pDir = opendir(path);
    	memset(childpath, 0, sizeof(childpath));
    	while ((ent = readdir(pDir)) != NULL)
    	{
    		if (ent->d_type & DT_DIR)
    		{
    
    			if (strcmp(ent->d_name, ".") == 0 || strcmp(ent->d_name, "..") == 0)
    			{
    				continue;
    			}
    			if(r) //假设须要遍历子文件夹
    			{
    				sprintf(childpath, "%s/%s", path, ent->d_name);
    				list_dir(childpath,files);
    			}
    		}
    		else
    		{
    			files.push_back(ent->d_name);
    		}
    	}
    	sort(files.begin(),files.end());//排序
    
    }
    
    string get_classname(string path)
    {
    	int index = path.find_last_of('_');
    	return path.substr(0, index);
    }
    
    
    int get_labelid(string fileName)
    {
    	string class_name_tmp = get_classname(fileName);
    	all_class_name.insert(class_name_tmp);
    	map<string,int>::iterator name_iter_tmp = class2id.find(class_name_tmp);
    	if (name_iter_tmp == class2id.end())
    	{
    		int id = class2id.size();
    		class2id.insert(name_iter_tmp, std::make_pair(class_name_tmp, id));
    		return id;
    	}
    	else
    	{
    		return name_iter_tmp->second;
    	}
    }
    
    void loadimg(string path,char* buffer)
    {
    	cv::Mat img = cv::imread(path, CV_LOAD_IMAGE_COLOR);
    	string val;
    	int rows = img.rows;
    	int cols = img.cols;
    	int pos=0;
    	for (int c = 0; c < 3; c++)
    	{
    		for (int row = 0; row < rows; row++)
    		{
    			for (int col = 0; col < cols; col++)
    			{
    				buffer[pos++]=img.at<cv::Vec3b>(row,col)[c];
    			}
    		}
    	}
    
    }
    void convert(string imgdir,string outputdb,string attachdir,int channel,int width,int height)
    {
    	leveldb::DB* db;
    	leveldb::Options options;
    	options.create_if_missing = true;
    	options.error_if_exists = true;
    	caffe::Datum datum;
    	datum.set_channels(channel);
    	datum.set_height(height);
    	datum.set_width(width);
    	int image_size = channel*width*height;
    	char buffer[image_size];
    
    	string value;
    	CHECK(leveldb::DB::Open(options, outputdb, &db).ok());
    	vector<string> filenames;
    	list_dir(imgdir.c_str(),filenames);
    	string img_log = attachdir+"image_filename";
    	ofstream writefile(img_log.c_str());
    	for(int i=0;i<filenames.size();i++)
    	{
    		string path= imgdir;
    		path.append(filenames[i]);//算出绝对路径
    
    		loadimg(path,buffer);
    
    		int labelid = get_labelid(filenames[i]);
    
    		datum.add_label(labelid);
    		datum.set_data(buffer,image_size);
    		datum.SerializeToString(&value);
    		snprintf(buffer, image_size, "%05d", i);
    		printf("
    classid:%d classname:%s abspath:%s",labelid,get_classname(filenames[i]).c_str(),path.c_str());
    		db->Put(leveldb::WriteOptions(),string(buffer),value);
    		//printf("%d %s
    ",i,fileNames[i].c_str());
    
    		assert(writefile.is_open());
    		writefile<<i<<" "<<filenames[i]<<"
    ";
    
    	}
    	delete db;
    	writefile.close();
    
    	img_log = attachdir+"image_classname";
    	writefile.open(img_log.c_str());
    	set<string>::iterator iter = all_class_name.begin();
    	while(iter != all_class_name.end())
    	{
    		assert(writefile.is_open());
    		writefile<<(*iter)<<"
    ";
    		//printf("%s
    ",(*iter).c_str());
    		iter++;
    	}
    	writefile.close();
    
    }
    
    int main(int argc, char** argv)
    {
    	if (argc < 6)
    	{
    	    LOG(ERROR) << "convert_imagedata src_dir dst_dir attach_dir channel width height";
    	    return 0;
    	}
    //./convert_imagedata.bin  /home/linger/imdata/collarTest/ /home/linger/linger/testfile/dbtest/  /home/linger/linger/testfile/test_attachment/ 3 250 250
    	//   ./convert_imagedata.bin /home/linger/imdata/collar_train/ /home/linger/linger/testfile/crop_train_db/ /home/linger/linger/testfile/crop_train_attachment/ 3 50 50
    	google::InitGoogleLogging(argv[0]);
    	string src_dir = argv[1];
    	string src_dst = argv[2];
    	string attach_dir = argv[3];
    	int channel = atoi(argv[4]);
    	int width = atoi(argv[5]);
    	int height = atoi(argv[6]);
    
    	//for test
    	/*
    	src_dir = "/home/linger/imdata/collarTest/";
    	src_dst = "/home/linger/linger/testfile/dbtest/";
    	attach_dir = "/home/linger/linger/testfile/";
    	channel = 3;
    	width = 250;
    	height = 250;
    	 */
    
    	convert(src_dir,src_dst,attach_dir,channel,width,height);
    
    
    
    }
    



  • 相关阅读:
    HTML5新标签与特性---多媒体
    HTML5新标签与特性---新表单+新属性----综合案例1
    字体图标引入到HTML---复制用代码
    字体图标网站---常用汇总
    滑动门出现的背景---实例微信导航栏(a盒子里面包span盒子,文字写在span里)
    【Web前端开发】---前端培训roadmap
    清除浮动的4种方法
    进度更新---Responsive Web Design Certification (300 hours)
    Python实现一个桌面版的翻译工具【新手必学】
    Python爬虫老是被封的解决方法【面试必问】
  • 原文地址:https://www.cnblogs.com/yxwkf/p/4003852.html
Copyright © 2011-2022 走看看