zoukankan      html  css  js  c++  java
  • darknet分类源码解析

      1 void validate_classifier_multi(char *datacfg, char *filename, char *weightfile)
      2 {
      3     int i, j;
      4     network net = parse_network_cfg(filename);
      5     set_batch_network(&net, 1);
      6     if(weightfile){
      7         load_weights(&net, weightfile);
      8     }
      9     srand(time(0));
     10 
     11     list *options = read_data_cfg(datacfg);//读.data文件到option列表中
     12 
     13     char *label_list = option_find_str(options, "labels", "data/labels.list");
     14     //从读到的.data生成的option列表去找对饮的字段如labels,将labels的配置路径放到label_list指针中,
     15     //然后如果labels的配置路径是"data/labels.list",打印“使用默认配置”字样
     16     char *valid_list = option_find_str(options, "valid", "data/train.list");// l,key,def;  return  def
     17     int classes = option_find_int(options, "classes", 2);
     18     int topk = option_find_int(options, "top", 1);
     19     if (topk > classes) topk = classes;//找的比类别还多
     20 
     21     char **labels = get_labels(label_list);
     22     //将labels.list标签名读到lables字符指针,可以通过labels[i]访问标签
     23     list *plist = get_paths(valid_list);//得到验证集的数据路径
     24     int scales[] = {224, 288, 320, 352, 384};
     25     int nscales = sizeof(scales)/sizeof(scales[0]);
     26 
     27     char **paths = (char **)list_to_array(plist);
     28     int m = plist->size;
     29     free_list(plist);
     30 
     31     float avg_acc = 0;
     32     float avg_topk = 0;
     33     int* indexes = (int*)calloc(topk, sizeof(int));
     34 
     35     for(i = 0; i < m; ++i){
     36         int class_id = -1;//一般用负数初始化
     37         char *path = paths[i];//这里的路径名包括文件名之外的路径吗?
     38         for(j = 0; j < classes; ++j){
     39             if(strstr(path, labels[j])){
     40                 //在path字符串中查找labels[j]字符串第一次出现的位置
     41                 class_id = j;
     42                 //这里实现了数据集在训练过程中的类别的确定。还是看匹配,只要标签在文件名中
     43                 break;
     44             }
     45         }
     46         float* pred = (float*)calloc(classes, sizeof(float));
     47         image im = load_image_color(paths[i], 0, 0);
     48         for(j = 0; j < nscales; ++j){
     49             image r = resize_min(im, scales[j]);
     50             resize_network(&net, r.w, r.h);
     51             float *p = network_predict(net, r.data);
     52             if(net.hierarchy) hierarchy_predictions(p, net.outputs, net.hierarchy, 1);
     53             axpy_cpu(classes, 1, p, 1, pred, 1);
     54             flip_image(r);
     55             p = network_predict(net, r.data);
     56             axpy_cpu(classes, 1, p, 1, pred, 1);
     57             if(r.data != im.data) free_image(r);
     58         }
     59         free_image(im);
     60         top_k(pred, classes, topk, indexes);
     61         free(pred);
     62         if(indexes[0] == class_id) avg_acc += 1;
     63         for(j = 0; j < topk; ++j){
     64             if(indexes[j] == class_id) avg_topk += 1;
     65         }
     66 
     67         printf("%d: top 1: %f, top %d: %f
    ", i, avg_acc/(i+1), topk, avg_topk/(i+1));
     68     }
     69 }
     70  
     71 
     72 void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filename, int top)
     73 {//反初始化主要是类对象的析构
     74     network net = parse_network_cfg_custom(cfgfile, 1, 0);
     75     if(weightfile){
     76         load_weights(&net, weightfile);
     77     }
     78     set_batch_network(&net, 1);
     79     srand(2222222);
     80 
     81     fuse_conv_batchnorm(net);
     82     calculate_binary_weights(net);
     83 
     84     list *options = read_data_cfg(datacfg);
     85 
     86     char *name_list = option_find_str(options, "names", 0);
     87     if(!name_list) name_list = option_find_str(options, "labels", "data/labels.list");
     88     int classes = option_find_int(options, "classes", 2);
     89     if (top == 0) top = option_find_int(options, "top", 1);
     90     if (top > classes) top = classes;
     91 
     92     int i = 0;
     93     char **names = get_labels(name_list);
     94     clock_t time;
     95     int* indexes = (int*)calloc(top, sizeof(int));
     96     char buff[256];
     97     char *input = buff;
     98     //int size = net.w;
     99     while(1){
    100         if(filename){
    101             strncpy(input, filename, 256);//将filename的前256个字符复制到input中。
    102         }else{
    103             printf("Enter Image Path: ");
    104             fflush(stdout);
    105             input = fgets(input, 256, stdin);
    106             if(!input) return;
    107             strtok(input, "
    ");
    108         }
    109         image im = load_image_color(input, 0, 0);
    110         image r = letterbox_image(im, net.w, net.h);
    111         //image r = resize_min(im, size);
    112         //resize_network(&net, r.w, r.h);
    113         printf("%d %d
    ", r.w, r.h);
    114 
    115         float *X = r.data;
    116         time=clock();
    117         float *predictions = network_predict(net, X);
    118         if(net.hierarchy) hierarchy_predictions(predictions, net.outputs, net.hierarchy, 0);
    119         top_k(predictions, net.outputs, top, indexes);
    120         //按得分来排top k,indexes是新的排序指针,按升序排列,prediction越大的在indexes里面的id越是靠后。
    121         printf("%s: Predicted in %f seconds.
    ", input, sec(clock()-time));
    122         for(i = 0; i < top; ++i){
    123             int index = indexes[i];
    124             //hierarchy是一个树形结构体指针变量。应该是没有的。
    125             if(net.hierarchy) printf("%d, %s: %f, parent: %s 
    ",index, names[index], predictions[index], (net.hierarchy->parent[index] >= 0) ? names[net.hierarchy->parent[index]] : "Root");
    126             else printf("%s: %f
    ",names[index], predictions[index]);
    127             //names[index]是分类的对应的类别名称如yb,ye,yf
    128             //predictions[index]是推理置信度
    129         }
    130         if(r.data != im.data) free_image(r);
    131         free_image(im);
    132         if (filename) break;//可以批量测试,如果filename是False,跳出
    133     }
    134 }
    135  
  • 相关阅读:
    Flask初识之安装及HelloWord程序
    Python 四大主流 Web 编程框架
    Mysql之锁、事务绝版详解---干货!
    Django基础九之中间件
    Django基础八之cookie和session
    Django基础七之Ajax
    Django基础六之ORM中的锁和事务
    Linux下使用tail查找日志文件关键词有颜色、高亮显示
    主流云测平台汇总
    RPC框架总述
  • 原文地址:https://www.cnblogs.com/Henry-ZHAO/p/12725182.html
Copyright © 2011-2022 走看看