zoukankan      html  css  js  c++  java
  • 转载 LibSVM文本分类之工程中调用LibSVM进行文本分类

    作者:finallyliuyu 转载使用等请注明出处

    首先介绍 libsvm 中的主要文件:svm.h 和 svm.cpp 这两个文件实现了 SVM 算法,svm-train.c 和 svm-predict.c 则分别完成训练和预测功能。

    本来我参照 svm-train.c、svm-predict.c 中的 main 函数,把训练和预测功能直接整合进自己的程序,结果调了一天仍然出现异常(我还是太菜了)。最后在同学的建议下,改为在工程中通过系统调用(system)运行 svm-train 和 svm-predict 程序。为了获得准确率(将分类准确率输出到文本文件),对 svm-predict.c 中的 predict 函数做了如下修改:

    注意 accuracy_file 相关的修改。相应地,svm-predict.c 的 main 函数也需要打开第四个(非选项)命令行参数所指定的文件,并将其作为 accuracy_file 传给 predict(原文未给出这部分修改)。

    /*
     * Modified copy of LIBSVM's svm-predict predict(): classifies every instance
     * read from `input` with the global `model` and reports the results.
     *
     * input         - test data, one instance per line:
     *                 "<label> <index>:<value> <index>:<value> ...".
     *                 A malformed line aborts via exit_input_error(line number).
     * output        - receives one predicted label per line. When probability
     *                 output is enabled for a C_SVC/NU_SVC model, a
     *                 "labels ..." header is written first and every line also
     *                 carries the per-class probability estimates.
     * accuracy_file - added in this modified version: for non-regression models
     *                 the accuracy percentage is written to it with "%f"
     *                 (no trailing newline). Nothing is written to it for
     *                 NU_SVR / EPSILON_SVR models.
     *                 NOTE(review): not NULL-checked here; the caller must pass
     *                 an open stream.
     *
     * Relies on state declared outside this function (model,
     * predict_probability, line, max_line_len, x, max_nr_attr, readline,
     * exit_input_error) -- presumably the globals of svm-predict.c; confirm.
     */
    void predict(FILE *input, FILE *output, FILE *accuracy_file)
    {
    	int correct = 0;
    	int total = 0;
    	double error = 0;
    	double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;
    
    	int svm_type=svm_get_svm_type(model);
    	int nr_class=svm_get_nr_class(model);
    	double *prob_estimates=NULL;
    	int j;
    
    	// Probability output: SVR models only print their Laplace sigma to stdout;
    	// other models write a "labels ..." header to `output` and get a buffer
    	// for the per-class estimates.
    	if(predict_probability)
    	{
    		if (svm_type==NU_SVR || svm_type==EPSILON_SVR)
    			printf("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",svm_get_svr_probability(model));
    		else
    		{
    			int *labels=(int *) malloc(nr_class*sizeof(int));
    			svm_get_labels(model,labels);
    			prob_estimates = (double *) malloc(nr_class*sizeof(double));
    			fprintf(output,"labels");		
    			for(j=0;j<nr_class;j++)
    				fprintf(output," %d",labels[j]);
    			fprintf(output,"\n");
    			free(labels);
    		}
    	}
    
    	// Line buffer used by readline(); allocated here and not freed in this
    	// function. NOTE(review): presumably freed by the caller -- confirm.
    	max_line_len = 1024;
    	line = (char *)malloc(max_line_len*sizeof(char));
    	while(readline(input) != NULL)
    	{
    		int i = 0;
    		double target_label, predict_label;
    		char *idx, *val, *label, *endptr;
    		int inst_max_index = -1; // strtol gives 0 if wrong format, and precomputed kernel has <index> start from 0
    
    		// The first token of the line is the true label.
    		label = strtok(line," \t");
    		target_label = strtod(label,&endptr);
    		if(endptr == label)
    			exit_input_error(total+1);
    
    		// Parse "index:value" pairs into x[] until the line is exhausted;
    		// indices must be strictly increasing.
    		while(1)
    		{
    			if(i>=max_nr_attr-1)	// need one more for index = -1
    			{
    				max_nr_attr *= 2;
    				// NOTE(review): realloc result is assigned to x unchecked; on
    				// failure the old buffer is lost and x becomes NULL.
    				x = (struct svm_node *) realloc(x,max_nr_attr*sizeof(struct svm_node));
    			}
    
    			idx = strtok(NULL,":");
    			val = strtok(NULL," \t");
    
    			if(val == NULL)
    				break;
    			errno = 0;
    			x[i].index = (int) strtol(idx,&endptr,10);
    			if(endptr == idx || errno != 0 || *endptr != '\0' || x[i].index <= inst_max_index)
    				exit_input_error(total+1);
    			else
    				inst_max_index = x[i].index;
    
    			errno = 0;
    			x[i].value = strtod(val,&endptr);
    			if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
    				exit_input_error(total+1);
    
    			++i;
    		}
    		// Terminate the sparse vector with the index = -1 sentinel.
    		x[i].index = -1;
    
    		if (predict_probability && (svm_type==C_SVC || svm_type==NU_SVC))
    		{
    			predict_label = svm_predict_probability(model,x,prob_estimates);
    			fprintf(output,"%g",predict_label);
    			for(j=0;j<nr_class;j++)
    				fprintf(output," %g",prob_estimates[j]);
    			fprintf(output,"\n");
    		}
    		else
    		{
    			predict_label = svm_predict(model,x);
    			fprintf(output,"%g\n",predict_label);
    		}
    
    		// Accumulate exact-match accuracy and the sums needed for the
    		// regression statistics, for every instance.
    		if(predict_label == target_label)
    			++correct;
    		error += (predict_label-target_label)*(predict_label-target_label);
    		sump += predict_label;
    		sumt += target_label;
    		sumpp += predict_label*predict_label;
    		sumtt += target_label*target_label;
    		sumpt += predict_label*target_label;
    		++total;
    	}
    	// NOTE(review): if `input` held no instances, total is 0 and the
    	// floating-point divisions below yield NaN/inf -- consider guarding.
    	if (svm_type==NU_SVR || svm_type==EPSILON_SVR)
    	{
    		// Regression: statistics go to stdout only; accuracy_file is untouched.
    		printf("Mean squared error = %g (regression)\n",error/total);
    		printf("Squared correlation coefficient = %g (regression)\n",
    		       ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
    		       ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt))
    		       );
    	}
    	else
    	{
    		// Classification: write the accuracy percentage (the modification in
    		// this version) and print it.
    		float accuracy_rate = (float)correct/total*100;
    		fprintf(accuracy_file, "%f", accuracy_rate);
    
    		printf("Accuracy = %g%% (%d/%d) (classification)\n",
    		       (double)correct/total*100,correct,total);
    	}
    
    	if(predict_probability)
    		free(prob_estimates);
    }
    

      调用 LibSVM 完成分类并汇总准确率的主程序(我的代码):

    #include <stdio.h>
    #include <stdlib.h>
    #include <string.h>
    #include "memory.h"
    
    /* Capacity (including the terminating NUL) of the shell command lines built
       by svm_train() and svm_predict(). */
    #define MAX_COMMAND_LINE_LENGTH 2048
    
    /*
     * Train a linear-kernel model by running the external LIBSVM trainer:
     *     <command_path> -t 0 <train_libsvm> <model_libsvm>
     * ("-t 0" selects the linear kernel in svm-train.)
     *
     * Returns 1 if the command ran and exited with status 0. Returns 0 if the
     * command failed, or if the command line would not fit in
     * MAX_COMMAND_LINE_LENGTH; in that case nothing is executed.
     * Arguments are inserted unquoted, so the paths must not contain spaces.
     */
    int svm_train(const char *command_path, const char *train_libsvm, const char *model_libsvm)
    {
    	char command_line[MAX_COMMAND_LINE_LENGTH];
    	int len = snprintf(command_line, sizeof command_line, "%s -t 0 %s %s",
    	                   command_path, train_libsvm, model_libsvm);
    
    	/* snprintf reports the length it needed: never run a truncated command. */
    	if (len < 0 || (size_t)len >= sizeof command_line) {
    		fprintf(stderr, "svm_train: command line too long, not executed\n");
    		return 0;
    	}
    	return system(command_line) == 0;
    }
    /*
     * Classify a test set by running the external LIBSVM predictor:
     *     <command_path> <test_libsvm> <model_libsvm> <result_path> <accuracy_path>
     * The fourth file argument assumes the modified svm-predict whose predict()
     * writes the accuracy percentage to an extra accuracy_file.
     *
     * Returns 1 if the command ran and exited with status 0. Returns 0 if the
     * command failed, or if the command line would not fit in
     * MAX_COMMAND_LINE_LENGTH; in that case nothing is executed.
     * Arguments are inserted unquoted, so the paths must not contain spaces.
     */
    int svm_predict(const char *command_path, const char *test_libsvm, const char *model_libsvm, const char *result_path, const char *accuracy_path)
    {
    	char command_line[MAX_COMMAND_LINE_LENGTH];
    	int len = snprintf(command_line, sizeof command_line, "%s %s %s %s %s",
    	                   command_path, test_libsvm, model_libsvm, result_path, accuracy_path);
    
    	/* snprintf reports the length it needed: never run a truncated command. */
    	if (len < 0 || (size_t)len >= sizeof command_line) {
    		fprintf(stderr, "svm_predict: command line too long, not executed\n");
    		return 0;
    	}
    	return system(command_line) == 0;
    }
    /*
     * Program entry point. Calls AccuracyFormation() to collect the
     * per-experiment accuracy files into the summary tables 0.txt..4.txt.
     * The LibSvm() train/predict batch is currently disabled.
     *
     * Waits for Enter before exiting, presumably to keep the console window
     * open. Returns 0 so that shells and batch files see success.
     */
    int main()
    {
    	void AccuracyFormation(void);
    	int LibSvm(void);
    
    	AccuracyFormation();
    	/* LibSvm(); -- re-enable to re-run training/prediction for every experiment. */
    
    	printf("finalfinish,congratulations!");
    	fflush(stdout);
    	(void)getchar(); /* pause until the user presses Enter */
    	return 0;
    }
    
    //char command_line[MAX_COMMAND_LINE_LENGTH] = {'\0'};
    //// train
    //sprintf(command_line, "..\\Release\\svm_train.exe -t 0 D:\\libsvmdata\\500\\0\\100\\train.libsvm D:\\libsvmdata\\500\\0\\100\\model.libsvm");
    //system(command_line);
    
    //// predict
    ////command_line[0] = '\0';
    //memset(command_line, 0, sizeof(command_line[0])*MAX_COMMAND_LINE_LENGTH);
    //
    //sprintf(command_line, "..\\Release\\svm_predict.exe D:\\libsvmdata\\500\\0\\100\\test.libsvm D:\\libsvmdata\\500\\0\\100\\model.libsvm D:\\libsvmdata\\500\\0\\100\\result.txt D:\\libsvmdata\\500\\0\\100\\accuracy.txt");
    //system(command_line);
    /* Write "<dir><suffix>" into dst. Returns 1 on success, 0 if it would not fit. */
    static int libsvm_join_path(char *dst, size_t size, const char *dir, const char *suffix)
    {
    	int n = snprintf(dst, size, "%s%s", dir, suffix);
    	return n >= 0 && (size_t)n < size;
    }
    
    /*
     * Batch train/predict driver for the whole experiment grid.
     *
     * It visits every combination of run id (0-4), corpus size (100, 500, 1000,
     * 1500) and feature dimension (10, 20, ..., 150), 300 in total, in that
     * nesting order. For each one it builds the directory
     *     D:\<run>_<corpus>_rfinish\TextCategorization_<run>_<corpus>_<dim>
     * and then:
     *   - runs svm_train() on <dir>\data\train.libsvm, writing model.libsvm;
     *   - runs svm_predict() on <dir>\data\test.libsvm, writing result.txt and
     *     accuracy.txt into the same data folder.
     *
     * Paths are built one combination at a time in small stack buffers. The
     * previous version kept a 300 x 5000-byte table on the stack (about 1.5 MB),
     * which overflows the default 1 MB Windows stack.
     *
     * NOTE(review): AccuracyFormation() reads from "..._r1\..." directories, not
     * "..._rfinish\..." -- confirm which layout is current.
     *
     * Returns 1 when every path was built. Returns 0 if any path did not fit its
     * buffer; such a combination is skipped.
     */
    int LibSvm()
    {
    	enum { LIBSVM_PATH_SIZE = 1024 };
    	static const char *const run_ids[] = { "0", "1", "2", "3", "4" };            /* experiment run index */
    	static const char *const corpus_sizes[] = { "100", "500", "1000", "1500" };  /* document-set size */
    	static const char *const feature_dimensions[] = {                            /* feature dimension */
    		"10", "20", "30", "40", "50", "60", "70", "80",
    		"90", "100", "110", "120", "130", "140", "150"
    	};
    	const size_t n_runs = sizeof run_ids / sizeof run_ids[0];
    	const size_t n_corpora = sizeof corpus_sizes / sizeof corpus_sizes[0];
    	const size_t n_dims = sizeof feature_dimensions / sizeof feature_dimensions[0];
    	char command_path_train[] = "..\\Release\\svm_train.exe";
    	char command_path_predict[] = "..\\Release\\svm_predict.exe";
    	char dir[LIBSVM_PATH_SIZE];
    	char train_path[LIBSVM_PATH_SIZE];
    	char test_path[LIBSVM_PATH_SIZE];
    	char result_path[LIBSVM_PATH_SIZE];
    	char model_path[LIBSVM_PATH_SIZE];
    	char accuracy_path[LIBSVM_PATH_SIZE];
    	int all_ok = 1;
    	size_t i, j, k;
    
    	for (i = 0; i < n_runs; i++) {
    		for (j = 0; j < n_corpora; j++) {
    			for (k = 0; k < n_dims; k++) {
    				int n = snprintf(dir, sizeof dir,
    				                 "D:\\%s_%s_rfinish\\TextCategorization_%s_%s_%s",
    				                 run_ids[i], corpus_sizes[j],
    				                 run_ids[i], corpus_sizes[j], feature_dimensions[k]);
    
    				if (n < 0 || (size_t)n >= sizeof dir
    				    || !libsvm_join_path(train_path, sizeof train_path, dir, "\\data\\train.libsvm")
    				    || !libsvm_join_path(test_path, sizeof test_path, dir, "\\data\\test.libsvm")
    				    || !libsvm_join_path(result_path, sizeof result_path, dir, "\\data\\result.txt")
    				    || !libsvm_join_path(model_path, sizeof model_path, dir, "\\data\\model.libsvm")
    				    || !libsvm_join_path(accuracy_path, sizeof accuracy_path, dir, "\\data\\accuracy.txt")) {
    					fprintf(stderr, "LibSvm: path too long, combination skipped\n");
    					all_ok = 0;
    					continue;
    				}
    
    				svm_train(command_path_train, train_path, model_path);
    				svm_predict(command_path_predict, test_path, model_path, result_path, accuracy_path);
    				printf("\n%s路径下的LibSVM分类完成\n", dir);
    			}
    		}
    	}
    	printf("试验完成\n");
    
    	return all_ok;
    }
    /* Read at most size-1 bytes of the file at path into buf and NUL-terminate it.
       Returns 1 on success, 0 if the file cannot be opened. */
    static int accuracy_read_value(const char *path, char *buf, size_t size)
    {
    	FILE *fp = fopen(path, "r");
    	size_t n;
    
    	if (fp == NULL)
    		return 0;
    	n = fread(buf, 1, size - 1, fp);
    	buf[n] = '\0';
    	fclose(fp);
    	return 1;
    }
    
    /* Append text to the file at path. Binary mode writes "\r\n" verbatim; a
       text-mode stream on Windows would turn it into "\r\r\n".
       Returns 1 on success, 0 if the file cannot be opened or fully written. */
    static int accuracy_append_text(const char *path, const char *text)
    {
    	FILE *fp = fopen(path, "ab");
    	size_t len = strlen(text);
    	int ok;
    
    	if (fp == NULL)
    		return 0;
    	ok = fwrite(text, 1, len, fp) == len;
    	if (fclose(fp) != 0)
    		ok = 0;
    	return ok;
    }
    
    /* Report the file that could not be processed and terminate with a failure status. */
    static void accuracy_fail(const char *path)
    {
    	fprintf(stderr, "FILENAME ERROR: %s\n", path);
    	exit(EXIT_FAILURE);
    }
    
    /*
     * Collect the accuracy.txt values produced for every experiment into one
     * summary table per run: 0.txt .. 4.txt in the current directory.
     *
     * For each run id i, the file dest_accuracy[i] gets one line per corpus size
     * (100, 500, 1000, 1500). Each line holds the accuracy for each feature
     * dimension (10 .. 150), separated by commas and ended with "\r\n". Values
     * are read from
     *     D:\<run>_<corpus>_r1\TextCategorization_<run>_<corpus>_<dim>\data\accuracy.txt
     * (at most 100 bytes each) and appended to any existing content.
     *
     * The process exits with EXIT_FAILURE if a file cannot be read or written.
     * NOTE(review): LibSvm() writes to "..._rfinish\..." directories, not
     * "..._r1\..." -- confirm which layout is current.
     */
    void AccuracyFormation()
    {
    	static const char *const run_ids[] = { "0", "1", "2", "3", "4" };            /* experiment run index */
    	static const char *const corpus_sizes[] = { "100", "500", "1000", "1500" };  /* document-set size */
    	static const char *const feature_dimensions[] = {                            /* feature dimension */
    		"10", "20", "30", "40", "50", "60", "70", "80",
    		"90", "100", "110", "120", "130", "140", "150"
    	};
    	static const char *const dest_accuracy[] = { "0.txt", "1.txt", "2.txt", "3.txt", "4.txt" };
    	const size_t n_runs = sizeof run_ids / sizeof run_ids[0];
    	const size_t n_corpora = sizeof corpus_sizes / sizeof corpus_sizes[0];
    	const size_t n_dims = sizeof feature_dimensions / sizeof feature_dimensions[0];
    	char accuracy_path[1024];
    	char value[101];              /* at most 100 bytes of file content + NUL */
    	char entry[sizeof value + 1]; /* value plus an optional trailing comma */
    	size_t i, j, k;
    
    	for (i = 0; i < n_runs; i++) {
    		for (j = 0; j < n_corpora; j++) {
    			for (k = 0; k < n_dims; k++) {
    				int n = snprintf(accuracy_path, sizeof accuracy_path,
    				                 "D:\\%s_%s_r1\\TextCategorization_%s_%s_%s\\data\\accuracy.txt",
    				                 run_ids[i], corpus_sizes[j],
    				                 run_ids[i], corpus_sizes[j], feature_dimensions[k]);
    				if (n < 0 || (size_t)n >= sizeof accuracy_path)
    					accuracy_fail(accuracy_path);
    				if (!accuracy_read_value(accuracy_path, value, sizeof value))
    					accuracy_fail(accuracy_path);
    
    				/* Comma-separate values within a row; no comma after the last one. */
    				snprintf(entry, sizeof entry, "%s%s", value, (k + 1 < n_dims) ? "," : "");
    				if (!accuracy_append_text(dest_accuracy[i], entry))
    					accuracy_fail(dest_accuracy[i]);
    				printf("%s处理完毕\n", accuracy_path);
    			}
    
    			/* End of the row for this corpus size. */
    			if (!accuracy_append_text(dest_accuracy[i], "\r\n"))
    				accuracy_fail(dest_accuracy[i]);
    			printf("一行处理完毕\n");
    		}
    
    		printf("%s填写完毕\n", dest_accuracy[i]);
    	}
    }
    

      

  • 相关阅读:
    2019.4.1 JMeter中文乱码解决方案
    19.3.25 sql查询语句
    2019.3.23 python的unittest框架与requests
    2019.3.22 JMeter基础操作
    19.3.21 计算机网络基础知识
    19.3.20 cmd操作:1.dir查看当前文件夹内的文件;2.alt+space+c关闭cmd窗口
    19.3.20 解决pycharm快捷键无法使用问题和熟悉git与码云操作流程
    19.3.19 使用Flask框架搭建一个简易登录服务器
    回调函数
    var img = new Image()
  • 原文地址:https://www.cnblogs.com/xiangshancuizhu/p/2227923.html
Copyright © 2011-2022 走看看