zoukankan      html  css  js  c++  java
  • 实习日记:图像检索算法 LSH 的总结与分析

    先贴上这两天刚出炉的C++代码。(利用 STL 偷了不少功夫,代码待优化)

    Head.h

     1 #ifndef HEAD_H
     2 #define HEAD_H
     3 
     4 #include "D:\LiYangGuang\VSPRO\MYLSH\HashTable.h"
     5 
     6 
     7 #include <iostream>
     8 #include <fstream>
     9 #include <time.h>
    10 #include <cstdlib>
    11 #include <vector>
    12 #include <map>
    13 #include <set>
    14 #include <string>
    15 
    16 using namespace std;
    17 
    18 
    19 void loadData(bool (*data)[128], int n, char *filename);
    20 void createTable(HashTable HTSet[], bool data[][128], bool extDat[][n][k] );
    21 void insert(HT HTSet[], bool (*extDat)[n][k]);
    22 void standHash(HT HTSet[]);
    23 void search(vector<int>& record, bool query[128], HT HTSet[]);
    24 /*int getPosition(int V[], std::string s, int N);*/
    25 
    26 #endif

    HashTable.h

    #ifndef HASHTABLE_H
    #define HASHTABLE_H
    
    #include <cstddef>   // NULL
    #include <string>
    #include <vector>
    
    // Global LSH parameters:
    //   k - number of sampled bit dimensions per table (length of a bucket key)
    //   l - number of hash tables
    //   n - number of data vectors
    //   M - number of slots in each table's second-level array Hash2 (== n)
    enum{ k = 15, l = 1, n = 587329, M = n};
    
    // One first-level bucket: all data items sharing the same k-bit key.
    typedef struct 
    {
        std::string key;        // k characters, each '0' or '1'
        std::vector<int> elem;  // element's index (row in the data array)
    } bucket; 
    
    // Node of a slot chain in the second-level table Hash2.
    // The first node of a chain is the slot itself; further nodes are linked
    // through `next`.
    struct INT 
    {
        bool used;          // slot/node holds a bucket index
        int val;            // index into HashTable::BukSet
        struct INT * next;  // next node in this slot's chain, or NULL
        INT() : used(false), val(0), next(NULL){}
    };
    
    // One LSH table. Hash2 alone holds M INT nodes (several MB), so HT objects
    // should be kept in static storage rather than on the stack.
    typedef struct HashTable 
    {
        int R[k];          // k random dimensions (indices into a 128-bit vector)
        int RNum[k];       // random weights less than M, used by the second-level hash
        //string DC;          // the contents of k dimensions 
        std::vector<bucket> BukSet;  // distinct keys, in first-seen order
        INT Hash2[M];                // second-level hash: slot -> chain of bucket indices
    } HT;
    
    #endif // HASHTABLE_H

    getPosition.h

    #include <string>
    // Second-level hash of a bucket key:
    //   position = (sum over col < N of V[col] * (s[col] - '0')) mod M
    // V : per-dimension random weights (HashTable::RNum), each expected < M
    // s : bucket key made of '0'/'1' characters; only the first N are read
    // N : number of key characters to hash (all callers pass k)
    // Returns a slot index in [0, M) for HashTable::Hash2.
    // Requires M (from HashTable.h) to be visible where this header is included.
    inline int getPosition(const int V[], const std::string& s, int N)
    {
    	// Never read past the end of a key shorter than N.
    	const int len = N < static_cast<int>(s.size()) ? N : static_cast<int>(s.size());
    	int position = 0;
    	for(int col = 0; col < len; ++col)
    	{
    		position += V[col] * (s[col] - '0');
    		position %= M;   // reduce every step so the sum cannot overflow
    	}
    	return position;
    }
    

     computeDistance.h

    // Hamming distance between two binary vectors: the number of positions
    // i in [0, N) where v1[i] != v2[i].
    // N defaults to 128, the feature length used throughout this project.
    // Without that default, an unqualified two-argument call such as
    // distance(a, b) under `using namespace std` silently resolves to
    // std::distance(bool*, bool*), a pointer difference.
    // The parameters are intentionally non-const: with const parameters that
    // template would be the better overload match for bool* arguments.
    inline int distance(bool v1[], bool v2[], int N = 128)
    {
    	int d = 0;
    	for(int i = 0; i < N; ++i)
    		d += v1[i] ^ v2[i];
    
    	return d;
    
    }
    

     main.cpp

    #include "Head.h"
    #include "D:\LiYangGuang\VSPRO\MYLSH\computeDistance.h"
    using namespace std;
    // Number of query vectors read from query.txt.
    const int MAX_Q = 1000; 
    
    // The arrays below are very large (n = 587329 rows), so they are kept at
    // namespace scope instead of on main's stack.
    HT HTSet[l];            // the l LSH tables
    
    bool data[n][128];      // data vectors: n rows x 128 binary features
    bool extDat[l][n][k];   // per table, the k sampled bits of every data row
    
    bool query[MAX_Q][128]; // set the query item to 1000.
    
    int main(int argc, char *argv[])
    {
    	/************************************************************************/
    	/*             Firstly, create the HashTables                           */
    	/************************************************************************/
    	// Backslashes must be escaped: "\L", "\V", "\M", "\d" are not valid
    	// escape sequences and get dropped, which mangles the path. Arrays are
    	// used because loadData takes a (non-const) char*.
    	char filename[] = "D:\\LiYangGuang\\VSPRO\\MYLSH\\data.txt";
    	loadData(data, n, filename);
    	createTable(HTSet, data, extDat);
    	insert(HTSet,extDat);
    	standHash(HTSet);
    
    	/************************************************************************/
    	/*              Secondly, start the LSH search                          */
    	/************************************************************************/
    
    	char queryFile[] = "D:\\LiYangGuang\\VSPRO\\MYLSH\\query.txt";
    	loadData(query, MAX_Q, queryFile);
    	clock_t time0 = clock();
    	for(int qId = 0; qId < MAX_Q; ++qId)
    	{
    		vector<int> record;
    		clock_t timeA = clock();
    		search(record, query[qId], HTSet);
    		set<int> Dis;
    		for(size_t i = 0; i < record.size(); ++i)
    		{
    			// Pass the length explicitly: an unqualified two-argument
    			// distance(a, b) resolves to std::distance (a pointer difference).
    			Dis.insert(distance(data[record[i]], query[qId], 128));
    		}
    		clock_t timeB = clock();
    		// Elapsed time is reported in clock() ticks.
    		cout << "第 " << qId + 1 << " 次查询时间:" << timeB - timeA << endl;
    	}
    	clock_t time1 = clock();
    	cout << "总查询时间:" << time1 - time0 << endl;
    
    	return 0;
    }
    

     loadData.cpp

    #include <string>
    #include <fstream>
    #include <iostream>
    
    /*
     * Reads n rows of 128 binary features from a text file into data.
     * Each line should hold at least 128 characters. The character c at column
     * j is stored as ((c - '0') & 1) != 0, so '0' -> false and '1' -> true.
     * Columns missing from a short line, and rows missing because the file
     * ended early, are set to false instead of reading past the string.
     * A file that cannot be opened, or that has fewer than n lines, is
     * reported on std::cerr.
     */
    void loadData(bool (*data)[128], int n, char* filename)
    {
    	std::ifstream ifs(filename);   // closed automatically on return
    	if(!ifs)
    		std::cerr << "loadData: cannot open " << filename << std::endl;
    
    	int row = 0;
    	std::string line;
    	for(; row < n && std::getline(ifs, line); ++row)
    	{
    		// Never index past the end of a short line.
    		const int avail = line.size() < 128 ? static_cast<int>(line.size()) : 128;
    		for(int col = 0; col < 128; ++col)
    			data[row][col] = col < avail && ((line[col] - '0') & 1) != 0;
    	}
    
    	if(ifs.is_open() && row < n)
    		std::cerr << "loadData: " << filename << " has only " << row
    		          << " of " << n << " rows" << std::endl;
    
    	// Rows the file did not provide are cleared rather than left stale.
    	for(; row < n; ++row)
    		for(int col = 0; col < 128; ++col)
    			data[row][col] = false;
    }
    

     creatTable.cpp

    #include "HashTable.h"
    #include <ctime>
    
    // Returns a pseudo-random integer in [0, bound), bound > 0.
    // A single rand() only spans [0, RAND_MAX], and RAND_MAX may be as small as
    // 32767 (e.g. MSVC), far below M. In that case two draws are combined so
    // the whole range can be reached.
    static int randBelow(int bound)
    {
    	unsigned long r = static_cast<unsigned long>(rand());
    	if(RAND_MAX < bound)
    		r = r * (static_cast<unsigned long>(RAND_MAX) + 1UL) + static_cast<unsigned long>(rand());
    	return static_cast<int>(r % static_cast<unsigned long>(bound));
    }
    
    /*
     * Initialises the l LSH tables.
     * For each table it draws k random dimensions R (in [0, 128)) and k random
     * weights RNum (in [0, M)). It then stores, for every data item, the k bits
     * at those dimensions in extDat[table][item][0..k).
     * The generator is seeded from the current time, so each run differs.
     */
    void createTable(HT HTSet[], bool data[][128], bool extDat[][n][k] )
    {
    	srand(static_cast<unsigned>(time(NULL)));
    	for(int tableNum = 0; tableNum < l; ++tableNum)
    	{
    		HT& table = HTSet[tableNum];
    
    		for(int randNum = 0; randNum < k; ++randNum)
    		{
    			table.R[randNum] = rand() % 128;
    			table.RNum[randNum] = randBelow(M);
    		}
    
    		// Row-major projection: read one data row, write k consecutive
    		// entries, instead of sweeping all n rows once per dimension.
    		for(int item = 0; item < n; ++item)
    			for(int randNum = 0; randNum < k; ++randNum)
    				extDat[tableNum][item][randNum] = data[item][table.R[randNum]];
    	}
    }
    

    insertData.cpp

    #include "HashTable.h"
    #include <iostream>
    #include <map>
    using namespace std;
    
    // Scratch index used by insert(): maps a bucket key to its position in
    // HTSet[t].BukSet while table t is being built; insert() clears it after
    // each table.
    map<string, int> deRepeat;
    // Returns true when the first n entries of V and V2 are identical
    // (vacuously true for n <= 0).
    // The original loop never advanced its index and hung on equal inputs.
    bool equal(bool V[], bool V2[], int n)
    {
    	for(int i = 0; i < n; ++i)
    	{
    		if(V[i] != V2[i])
    			return false;
    	}
    	return true;
    }
    
    // Appends the first n flags of v to s as '0'/'1' characters and returns
    // the result. s is taken by value, so the caller's string is unchanged.
    string itoa(bool *v, int n, string s)
    {
    	int idx = 0;
    	while(idx < n)
    	{
    		const char digit = static_cast<char>(v[idx] + '0');
    		s += digit;
    		++idx;
    	}
    	return s;
    }
    
    /*
     * Groups the n items of every table into first-level buckets.
     * An item's key is its k-bit projection extDat[t][item], written as a
     * '0'/'1' string. Each distinct key gets one bucket, appended to
     * HTSet[t].BukSet in first-seen order, and bucket.elem lists (in
     * ascending order) the item indices sharing that key.
     * The global deRepeat maps key -> bucket index while a table is built and
     * is cleared afterwards.
     * Per-item debug printing was removed: it flushed ~3 lines per item.
     */
    void insert(HT HTSet[], bool (*extDat)[n][k])
    {
    	for(int t = 0; t < l; ++t) /* t: table */
    	{
    		vector<bucket>& buckets = HTSet[t].BukSet;
    		for(int item = 0; item < n; ++item)
    		{
    			string key = itoa(extDat[t][item], k, string(""));
    			// A single lookup: insert() only adds the key if it is new, and
    			// res.first always points at the entry holding the bucket index.
    			pair<map<string, int>::iterator, bool> res =
    				deRepeat.insert(make_pair(key, static_cast<int>(buckets.size())));
    			if(res.second)
    			{
    				bucket bkt;
    				bkt.key = key;
    				buckets.push_back(bkt);
    			}
    			buckets[res.first->second].elem.push_back(item);
    		}
    		deRepeat.clear();
    	}
    }
    

     standHash.cpp

    #include "HashTable.h"
    #include <iostream>
    #include "getPosition.h"
    
    /*
     * Builds the second-level hash of every table. Bucket b of HTSet[t].BukSet
     * goes to slot getPosition(RNum, key, k) of Hash2. The first bucket that
     * lands in a slot occupies the slot node itself. Later ones are appended
     * to the end of that slot's chain as heap-allocated INT nodes, which are
     * not freed here.
     */
    void standHash(HT HTSet[])
    {
    	for(int t = 0; t < l; ++t)
    	{
    		HT& table = HTSet[t];
    		const int bucketCount = table.BukSet.size();
    		for(int b = 0; b < bucketCount; ++b)
    		{
    			const int slot = getPosition(table.RNum, table.BukSet[b].key, k);
    			INT *node = table.Hash2 + slot;
    			// Advance to the last occupied node of the chain.
    			while(node->used && node->next != NULL)
    				node = node->next;
    			if(!node->used)
    			{
    				node->val = b;
    				node->used = true;
    				continue;
    			}
    			INT *fresh = new INT;
    			node->next = fresh;
    			fresh->val = b;
    			fresh->used = true;
    		}
    		std::cout << "the " << t << "th HashTable has been finished." << std::endl;
    	}
    }
    

     search.cpp

    #include "HashTable.h"
    #include "getPosition.h"
    #include <vector>
    using namespace std;
    
    /*
     * LSH lookup. For each table, projects the query onto the table's k sampled
     * dimensions, hashes that key into Hash2, and walks the slot's chain.
     * The items of every bucket whose key equals the query key are appended
     * to record. With l > 1 the same item may be appended more than once.
     */
    void search(vector<int>& record, bool query[128], HT HTSet[])
    {
    	for(int t = 0; t < l; ++t)
    	{
    		HT& table = HTSet[t];
    		string temKey;
    		for(int c = 0; c < k; ++c)
    			temKey.push_back(query[table.R[c]] + '0');
    		const int temPos = getPosition(table.RNum, temKey, k);
    		for(const INT *p = &table.Hash2[temPos]; p != NULL && p->used; p = p->next)
    		{
    			// Bind by reference: a bucket owns a string and a vector, and
    			// copying it on every probe was pure overhead.
    			const bucket& temB = table.BukSet[p->val];
    			// Different keys can share a slot, so compare the full key.
    			if(temKey == temB.key)
    				record.insert(record.end(), temB.elem.begin(), temB.elem.end());
    		}
    	}
    }
    

     

     稍后总结。

    代码调整:

    main.cpp

    #include "Head.h"
    #include "D:\LiYangGuang\VSPRO\MYLSH\MYLSH\computeDistance.h"
    using namespace std;
    #pragma warning(disable: 4996)
    // Number of query vectors read from query.txt.
    const int MAX_Q = 1000; 
    
    HT HTSet[l];            // the l LSH tables (large: kept at namespace scope)
    
    bool data[n][128];      // n data vectors x 128 binary features
    bool extDat[l][n][k];   // per table, the k sampled bits of every data vector
    
    bool query[MAX_Q][128]; // the MAX_Q query vectors
    
    /*
     * Writes "<v>.txt" into FileName, with v in decimal (e.g. 1 -> "1.txt").
     * FileName must have room for the digits, ".txt" and the terminating NUL.
     * itoa is the MSVC CRT extension; it returns its buffer argument.
     */
    void getFileName(int v, char *FileName)
    {
    	static const char kExt[] = ".txt";
    	strcat(itoa(v, FileName, 10), kExt);
    }
    
    
    
    int main(int argc, char *argv)
    {
    	/************************************************************************/
    	/*             Firstly, create the HashTables                           */
    	/************************************************************************/
    	char *filename = "D:\LiYangGuang\VSPRO\MYLSH\data.txt";
    	loadData(data, n, filename);
    	createTable(HTSet, data, extDat);
    	insert(HTSet,extDat);
    	standHash(HTSet);
    
    	char *queryFile = "D:\LiYangGuang\VSPRO\MYLSH\query.txt";
    	loadData(query, MAX_Q, queryFile);
    	/************************************************************************/
    	/*               Secondly, start the linear Search                       */
    // 	/************************************************************************/
    // 
    // 	vector<RECORD> record2;
    // 	clock_t LineTime1 = clock();
    // 	for(int qId = 0; qId < MAX_Q; ++qId)
    // 	{
    // 		for(int i = 0; i < n; ++i)
    // 		{
    // 			RECORD tem;
    // 			tem.Id = i;
    // 			tem.Dis = distance(data[i], query[qId]);
    // 			record2.push_back(tem);
    // 		}
    // 		record2.clear();
    // 	}
    // 	clock_t LineTime2 = clock();
    // 	float LineTime = (float)(LineTime2 - LineTime1) / CLOCKS_PER_SEC;
    // 	cout << "全部线性查询时间:" << LineTime << " s," << " 合"
    // 		<< LineTime / 60 << " minutes."<< endl;
    // 
    // 	/************************************************************************/
    // 	/*              Thirdly, start the LSH search                          */
    // 	/************************************************************************/
    // 
    // 	clock_t time0 = clock();
    // 	ofstream ofs;
    // 	char outFileName[10] = { ''};
    // 	int K = 1; /// define KNN
    // 	getFileName(K, outFileName);
    // 	ofs.out(outFileName);
    // 
    // 	for(int qId = 0; qId < MAX_Q; ++qId)
    // 	{
    // 		vector<RECORD> record;
    // 		clock_t timeA = clock();
    // 		search(record, query[qId], HTSet, data);
    // 		if(getkNN(record,K))
    // 		clock_t timeB = clock();
    // 		record.clear();
    // 		cout << "第 " << qId + 1 << " 次查询时间:" << 
    // 			(float)(timeB - timeA) / CLOCKS_PER_SEC << " s" << endl;
    // 	}
    // 	clock_t time1 = clock();
    // 	cout << "总查询时间:" << (float)(time1 - time0) / CLOCKS_PER_SEC 
    // 		<< " s." << endl;
    /************************************************************************/
    /*                                                                      */
    /************************************************************************/
    	ofstream ofs;
    	char outFileName[10] = { ''};
    	int K = 1; /// define KNN
    	getFileName(K, outFileName);
    	ofs.open(outFileName, ios::out);
    	//ofs.precision(3);
    	float TotalLinearTime, TotalLSHTime;
    	TotalLinearTime = TotalLSHTime = 0;
    
    	float TotalError = 0;
    	int TotalMiss = 0;
    
    
    	vector<RECORD> record2;
    	for(int qId = 0; qId < MAX_Q; ++qId)
    	{
    		cout << "第 " << qId << " 次查询" << endl;
    		clock_t LineTime1 = clock();
    		for(int i = 0; i < n; ++i)
    		{
    			RECORD tem;
    			tem.Id = i;
    			tem.Dis = computeDistance(data[i], query[qId], 128);
    			record2.push_back(tem);
    		}
    	   getkNN(record2); // 利用其对距离排序
    	   clock_t LineTime2 = clock();
    	   float LineTime = (float)(LineTime2 - LineTime1) / CLOCKS_PER_SEC;
    	   TotalLinearTime += LineTime;
    
    	/************************************************************************/
    	/*              Thirdly, start the LSH search                          */
    	/************************************************************************/
    
    		vector<RECORD> record;
    		clock_t timeA = clock();
    		search(record, query[qId], HTSet, data);
    		if(!getkNN(record, K)) 
    		{
    			float queryTime = (float)(clock() - timeA) / CLOCKS_PER_SEC;
    			TotalLSHTime += queryTime;
    			ofs << "Miss	" << "LSH Time: " << queryTime 
    				<< "s	Linear time: " << LineTime << 's' << endl;
    			TotalMiss += 1;
    		}
    		else{
    			float queryTime = (float)(clock() - timeA) / CLOCKS_PER_SEC;
    			TotalLSHTime += queryTime;
    			float error = 0;
    			if(record[K-1].Dis == 0)
    				error = 1;
    			else
    				error = (float)record2[K-1].Dis / record[K-1].Dis;
    			ofs << "Error: " << error << "	LSH Time: " 
    				<< queryTime << "s	Linear time: " << LineTime << 's' << endl;
    			TotalError += error;
    
    		}
    		record.clear();
    		record2.clear();
    	}
    	ofs << "Average errror: " << TotalError / 817 << endl;//recitfy
    	ofs << "Miss ratio: " << TotalMiss / MAX_Q << endl;
    	ofs << "Total query time: " << "LSH, " << TotalLSHTime / 3600 << " h; "
    		<< "Linear, " << TotalLinearTime / 3600 << " h." << endl;
    	ofs.close();
    
    
    	return 0;
    
    }
    

     computeDistance.h

    // Hamming distance between two binary vectors: the number of positions
    // i in [0, N) where v1[i] != v2[i]. N defaults to 128, the feature length
    // used throughout this project. Inputs are only read, hence const.
    inline int computeDistance(const bool v1[], const bool v2[], int N = 128)
    {
    	int d = 0;
    	for(int i = 0; i < N; ++i)
    		d += (v1[i] != v2[i]);
    
    	return d;
    }
    

     Search.cpp

    #include "HashTable.h"
    #include "getPosition.h"
    #include "computeDistance.h"
    #include <vector>
    using namespace std;
    
    /***    data is passed in so the distance of each candidate can be computed  ***/
    /*
     * LSH lookup returning scored candidates. For each table, projects the
     * query onto the table's k sampled dimensions, hashes that key into Hash2,
     * and walks the slot's chain. Every item of a bucket whose key equals the
     * query key is appended to record as a RECORD: Id is the item index and
     * Dis is computeDistance(data[Id], query, 128).
     */
    void search(vector<RECORD>& record, bool query[128], HT HTSet[], bool data[][128])
    {
    	for(int t = 0; t < l; ++t)
    	{
    		HT& table = HTSet[t];
    		string temKey;
    		for(int c = 0; c < k; ++c)
    			temKey.push_back(query[table.R[c]] + '0');
    		const int temPos = getPosition(table.RNum, temKey, k);
    		for(const INT *p = &table.Hash2[temPos]; p != NULL && p->used; p = p->next)
    		{
    			// Bind by reference: copying a bucket (string + vector) on every
    			// probe was pure overhead.
    			const bucket& temB = table.BukSet[p->val];
    			if(temKey != temB.key)
    				continue;   // slot collision with a different key
    			for(size_t j = 0; j < temB.elem.size(); ++j)
    			{
    				RECORD temp;
    				temp.Id = temB.elem[j];
    				temp.Dis = computeDistance(data[temp.Id], query, 128);
    				record.push_back(temp);
    			}
    		}
    	}
    }
    

     

    相关截图:

  • 相关阅读:
    SQL基础学习_03_数据更新
    SQL基础学习_02_查询
    SQL基础学习_01_数据库和表
    HCA数据下载
    Multiclonal Invasion in Breast Tumors Identified by Topographic Single Cell Sequencing
    gg_pie
    ggnetwork
    ggplot2画简单的heatmap
    简单R语言爬虫
    突变数据清洗
  • 原文地址:https://www.cnblogs.com/liyangguang1988/p/3875998.html
Copyright © 2011-2022 走看看