zoukankan      html  css  js  c++  java
  • OpenCV 中 SVM 模型 XML 文件的读写

    以下代码摘自 OpenCV 3.4.9 中 SVM 实现的 write()/read() 成员函数。

    写入 SVM 模型的 XML 信息(write):

     1 void write( FileStorage& fs ) const CV_OVERRIDE // Serializes the trained model (params, support vectors, decision functions) into fs.
     2     {
     3         int class_count = !class_labels.empty() ? (int)class_labels.total() : // number of stored labels if any,
     4                           params.svmType == ONE_CLASS ? 1 : 0; // otherwise 1 for ONE_CLASS and 0 for every other SVM type
     5         if( !isTrained() ) // refuse to serialize a model that has not been trained
     6             CV_Error( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" ); // NOTE(review): reuses the read-side parse error code/message in a write path
     7 
     8         writeFormat(fs); // defined elsewhere; presumably records the storage format version — TODO confirm
     9         write_params( fs ); // hyper-parameters (defined elsewhere)
    10 
    11         fs << "var_count" << var_count; // features per sample
    12 
    13         if( class_count > 0 ) // class info is emitted only when labels exist or the type is ONE_CLASS
    14         {
    15             fs << "class_count" << class_count;
    16 
    17             if( !class_labels.empty() )
    18                 fs << "class_labels" << class_labels;
    19 
    20             if( !params.classWeights.empty() )
    21                 fs << "class_weights" << params.classWeights;
    22         }
    23 
    24         // write the joint collection of support vectors
    25         int i, sv_total = sv.rows; // one support vector per row of sv
    26         fs << "sv_total" << sv_total;
    27         fs << "support_vectors" << "["; // outer sequence: one element per support vector
    28         for( i = 0; i < sv_total; i++ )
    29         {
    30             fs << "[:"; // "[:" opens a compact (flow-style) sequence for the row data
    31             fs.writeRaw("f", sv.ptr(i), sv.cols*sv.elemSize()); // "f" = float elements; length passed in bytes (cols * elemSize())
    32             fs << "]";
    33         }
    34         fs << "]";
    35 
    36         if ( !uncompressed_sv.empty() ) // optional section: written only when an uncompressed copy is kept
    37         {
    38             // write the joint collection of uncompressed support vectors
    39             int uncompressed_sv_total = uncompressed_sv.rows;
    40             fs << "uncompressed_sv_total" << uncompressed_sv_total;
    41             fs << "uncompressed_support_vectors" << "[";
    42             for( i = 0; i < uncompressed_sv_total; i++ )
    43             {
    44                 fs << "[:";
    45                 fs.writeRaw("f", uncompressed_sv.ptr(i), uncompressed_sv.cols*uncompressed_sv.elemSize()); // same raw float-row layout as support_vectors
    46                 fs << "]";
    47             }
    48             fs << "]";
    49         }
    50 
    51         // write decision functions
    52         int df_count = (int)decision_func.size();
    53 
    54         fs << "decision_functions" << "[";
    55         for( i = 0; i < df_count; i++ )
    56         {
    57             const DecisionFunc& df = decision_func[i];
    58             int sv_count = getSVCount(i); // defined elsewhere; presumably the number of support vectors used by function i
    59             fs << "{" << "sv_count" << sv_count // one mapping per decision function
    60                << "rho" << df.rho
    61                << "alpha" << "[:";
    62             fs.writeRaw("d", (const uchar*)&df_alpha[df.ofs], sv_count*sizeof(df_alpha[0])); // "d" = double coefficients starting at offset df.ofs; length in bytes
    63             fs << "]";
    64             if( class_count >= 2 ) // multi-class: store this function's support-vector indices
    65             {
    66                 fs << "index" << "[:";
    67                 fs.writeRaw("i", (const uchar*)&df_index[df.ofs], sv_count*sizeof(df_index[0])); // "i" = int indices (presumably rows of sv)
    68                 fs << "]";
    69             }
    70             else
    71                 CV_Assert( sv_count == sv_total ); // single function: indices are implicit, so it must use every support vector
    72             fs << "}";
    73         }
    74         fs << "]";
    75     }

    读取 SVM 模型的 XML 信息(read):

     1  void read( const FileNode& fn ) CV_OVERRIDE // Restores a model from the node layout produced by write(); raises CV errors on missing/inconsistent tags.
     2     {
     3         clear(); // defined elsewhere; presumably resets any previously loaded model state
     4 
     5         // read SVM parameters
     6         read_params( fn );
     7 
     8         // and top-level data
     9         int i, sv_total = (int)fn["sv_total"]; // presumably 0 when the tag is missing — rejected by the check below
    10         var_count = (int)fn["var_count"];
    11         int class_count = (int)fn["class_count"]; // write() omits this tag when its class_count is 0
    12 
    13         if( sv_total <= 0 || var_count <= 0 )
    14             CV_Error( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" );
    15 
    16         FileNode m = fn["class_labels"]; // optional tag
    17         if( !m.empty() )
    18             m >> class_labels;
    19         m = fn["class_weights"]; // optional tag
    20         if( !m.empty() )
    21             m >> params.classWeights;
    22 
    23         if( class_count > 1 && (class_labels.empty() || (int)class_labels.total() != class_count)) // multi-class models need exactly class_count labels
    24             CV_Error( CV_StsParseError, "Array of class labels is missing or invalid" );
    25 
    26         // read support vectors
    27         FileNode sv_node = fn["support_vectors"];
    28 
    29         CV_Assert((int)sv_node.size() == sv_total); // one element per support vector
    30 
    31         sv.create(sv_total, var_count, CV_32F); // one float row of var_count values per support vector
    32         FileNodeIterator sv_it = sv_node.begin();
    33         for( i = 0; i < sv_total; i++, ++sv_it )
    34         {
    35             (*sv_it).readRaw("f", sv.ptr(i), var_count*sv.elemSize()); // length passed in bytes, mirroring write()
    36         }
    37 
    38         int uncompressed_sv_total = (int)fn["uncompressed_sv_total"]; // optional section; presumably 0 when absent
    39 
    40         if( uncompressed_sv_total > 0 )
    41         {
    42             // read uncompressed support vectors
    43             FileNode uncompressed_sv_node = fn["uncompressed_support_vectors"];
    44 
    45             CV_Assert((int)uncompressed_sv_node.size() == uncompressed_sv_total);
    46             uncompressed_sv.create(uncompressed_sv_total, var_count, CV_32F);
    47 
    48             FileNodeIterator uncompressed_sv_it = uncompressed_sv_node.begin();
    49             for( i = 0; i < uncompressed_sv_total; i++, ++uncompressed_sv_it )
    50             {
    51                 (*uncompressed_sv_it).readRaw("f", uncompressed_sv.ptr(i), var_count*uncompressed_sv.elemSize());
    52             }
    53         }
    54 
    55         // read decision functions
    56         int df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1; // one function per pair of classes, otherwise a single function
    57         FileNode df_node = fn["decision_functions"];
    58 
    59         CV_Assert((int)df_node.size() == df_count);
    60 
    61         FileNodeIterator df_it = df_node.begin();
    62         for( i = 0; i < df_count; i++, ++df_it )
    63         {
    64             FileNode dfi = *df_it;
    65             DecisionFunc df;
    66             int sv_count = (int)dfi["sv_count"]; // NOTE(review): not validated (<= 0, or != sv_total when class_count < 2) before it sizes the buffers below
    67             int ofs = (int)df_index.size(); // alphas/indices of all functions are stored back-to-back
    68             df.rho = (double)dfi["rho"];
    69             df.ofs = ofs;
    70             df_index.resize(ofs + sv_count);
    71             df_alpha.resize(ofs + sv_count);
    72             dfi["alpha"].readRaw("d", (uchar*)&df_alpha[ofs], sv_count*sizeof(df_alpha[0])); // NOTE(review): &df_alpha[ofs] indexes one past the end when sv_count == 0
    73             if( class_count >= 2 ) // indices are stored only for multi-class models
    74                 dfi["index"].readRaw("i", (uchar*)&df_index[ofs], sv_count*sizeof(df_index[0]));
    75             decision_func.push_back(df);
    76         }
    77         if( class_count < 2 ) // single function: write() asserted it uses every support vector
    78             setRangeVector(df_index, sv_total); // defined elsewhere; presumably fills df_index with 0..sv_total-1
    79         if( (int)fn["optimize_linear"] != 0 ) // optional flag, presumably written by write_params — TODO confirm
    80             optimize_linear_svm();
    81     }
    82 
    83     SvmParams params; // hyper-parameters; write()/read() use svmType and classWeights
    84     Mat class_labels; // one label per class; serialized as "class_labels" when non-empty
    85     int var_count; // features per sample (row width of sv as allocated by read())
    86     Mat sv, uncompressed_sv; // support vectors, one CV_32F row each; uncompressed_sv is optional (written only when non-empty)
    87     vector<DecisionFunc> decision_func; // per-function rho and offset (ofs) into df_alpha/df_index
    88     vector<double> df_alpha; // alpha coefficients of all decision functions, stored back-to-back
    89     vector<int> df_index; // support-vector indices of all decision functions, stored back-to-back
    90 
    91     Ptr<Kernel> kernel; // kernel object (not referenced by the write()/read() shown here)
    92 };
  • 相关阅读:
    iOS
    iOS
    iOS
    iOS
    iOS
    使用jquery获取radio的值
    CSS margin属性与用法教程
    CSS框架960Grid从入门到精通一步登天
    从程序员到项目经理
    华为离职副总裁徐家骏:年薪千万的工作感悟
  • 原文地址:https://www.cnblogs.com/juluwangshier/p/13094634.html
Copyright © 2011-2022 走看看