zoukankan      html  css  js  c++  java
  • 图像分割之(四)OpenCV的GrabCut函数使用和源码解读



            分类:             图像处理             计算机视觉                   12031人阅读     评论(33)     收藏     举报    





          上一篇文章对 GrabCut 做了初步介绍。OpenCV 中的 GrabCut 算法是依据《"GrabCut" — Interactive Foreground Extraction using Iterated Graph Cuts》这篇论文实现的。现在我对源码做了些注释，以便更深入地了解该算法。我一直觉得论文和代码之间有比较大的差别：脱离代码只看论文，最多能看懂 70%；剩下的 20% 左右需要通过阅读代码来获得；还有 10% 则和每个人的基础与知识储备相关。








    // Signature of cv::grabCut; the annotated implementation is listed further below.
    //   _img                  input image; the implementation requires a non-empty CV_8UC3 image
    //   _mask                 in/out CV_8UC1 label mask (GC_BGD, GC_FGD, GC_PR_BGD, GC_PR_FGD)
    //   rect                  foreground rectangle, used when mode == GC_INIT_WITH_RECT
    //   _bgdModel, _fgdModel  background / foreground GMM parameters (1 x 13*5 CV_64FC1)
    //   iterCount             number of iterations of the main segmentation loop
    //   mode                  GC_INIT_WITH_RECT, GC_INIT_WITH_MASK or GC_EVAL
    void cv::grabCut( InputArray _img, InputOutputArray _mask, Rect rect,

                      InputOutputArray _bgdModel, InputOutputArray _fgdModel,

                      int iterCount, int mode )





















           源码中还包含了 gcgraph.hpp，它实现了图的构建以及 max flow/min cut 算法；这个文件暂时没有解读，后面再更新。

    //  By downloading, copying, installing or using the software you agree to this license.
    //  If you do not agree to this license, do not download, install,
    //  copy or use the software.
    //                        Intel License Agreement
    //                For Open Source Computer Vision Library
    // Copyright (C) 2000, Intel Corporation, all rights reserved.
    // Third party copyrights are property of their respective owners.
    // Redistribution and use in source and binary forms, with or without modification,
    // are permitted provided that the following conditions are met:
    //   * Redistribution's of source code must retain the above copyright notice,
    //     this list of conditions and the following disclaimer.
    //   * Redistribution's in binary form must reproduce the above copyright notice,
    //     this list of conditions and the following disclaimer in the documentation
    //     and/or other materials provided with the distribution.
    //   * The name of Intel Corporation may not be used to endorse or promote products
    //     derived from this software without specific prior written permission.
    // This software is provided by the copyright holders and contributors "as is" and
    // any express or implied warranties, including, but not limited to, the implied
    // warranties of merchantability and fitness for a particular purpose are disclaimed.
    // In no event shall the Intel Corporation or contributors be liable for any direct,
    // indirect, incidental, special, exemplary, or consequential damages
    // (including, but not limited to, procurement of substitute goods or services;
    // loss of use, data, or profits; or business interruption) however caused
    // and on any theory of liability, whether in contract, strict liability,
    // or tort (including negligence or otherwise) arising in any way out of
    // the use of this software, even if advised of the possibility of such damage.
    #include "precomp.hpp"
    #include "gcgraph.hpp"
    #include <limits>
    using namespace cv;
    /*
    This is an implementation of the image segmentation algorithm GrabCut described in
    "GrabCut — Interactive Foreground Extraction using Iterated Graph Cuts".
    Carsten Rother, Vladimir Kolmogorov, Andrew Blake.
    */

    /*
     GMM - Gaussian Mixture Model
    */
    /*
     Gaussian Mixture Model of 3-channel colours with componentsCount components.
     The parameters (weights, means, covariances) are stored in an external
     1 x (13*componentsCount) CV_64FC1 matrix supplied by the caller, so they are
     written back into that matrix; inverse covariances and determinants are
     cached inside this object.
    */
    class GMM
    {
    public:
        static const int componentsCount = 5;

        // Binds to (and, if empty, allocates) the external parameter matrix.
        GMM( Mat& _model );
        // Mixture likelihood: sum over components of weight * component density.
        double operator()( const Vec3d color ) const;
        // Density of a colour under component ci (0 for a zero-weight component).
        double operator()( int ci, const Vec3d color ) const;
        // Index of the component with the highest density for the colour.
        int whichComponent( const Vec3d color ) const;

        // Learning protocol: initLearning(), addSample() for every sample, endLearning().
        void initLearning();
        void addSample( int ci, const Vec3d color );
        void endLearning();

    private:
        void calcInverseCovAndDeterm( int ci );

        Mat model;      // shares its data with the caller's parameter matrix
        double* coefs;  // componentsCount mixture weights
        double* mean;   // componentsCount x 3 means
        double* cov;    // componentsCount x 9 covariance matrices (row-major 3x3)

        double inverseCovs[componentsCount][3][3]; // inverse of each covariance matrix
        double covDeterms[componentsCount];        // determinant of each covariance matrix

        // Accumulators filled between initLearning() and endLearning().
        double sums[componentsCount][3];
        double prods[componentsCount][3][3];
        int sampleCounts[componentsCount];
        int totalSampleCount;
    };
    /*
     Binds the GMM to the parameter matrix _model.
     An empty matrix is allocated as 1 x (13*componentsCount) CV_64FC1 and zero-filled,
     so every weight is 0 and no component is active yet. A non-empty matrix must
     already have exactly that layout.
     Layout: [componentsCount weights | 3*componentsCount means | 9*componentsCount covariances].
    */
    GMM::GMM( Mat& _model )
    {
        const int modelSize = 3/*mean*/ + 9/*covariance*/ + 1/*component weight*/;
        if( _model.empty() )
        {
            _model.create( 1, modelSize*componentsCount, CV_64FC1 );
            // Mat::create does not initialise the data; without zeroing, the
            // coefs[ci] > 0 test below would read garbage.
            _model.setTo( Scalar(0) );
        }
        else if( (_model.type() != CV_64FC1) || (_model.rows != 1) || (_model.cols != modelSize*componentsCount) )
            CV_Error( CV_StsBadArg, "_model must have CV_64FC1 type, rows == 1 and cols == 13*componentsCount" );

        model = _model;

        coefs = model.ptr<double>(0);    // mixture weights
        mean = coefs + componentsCount;  // means, 3 per component
        cov = mean + 3*componentsCount;  // covariances, 9 per component

        // Rebuild the cached inverses/determinants for components already present
        // in a previously filled parameter matrix.
        for( int ci = 0; ci < componentsCount; ci++ )
            if( coefs[ci] > 0 )
                calcInverseCovAndDeterm( ci );
    }
    double GMM::operator()( const Vec3d color ) const
        double res = 0;
        for( int ci = 0; ci < componentsCount; ci++ )
            res += coefs[ci] * (*this)(ci, color );
        return res;
    double GMM::operator()( int ci, const Vec3d color ) const
        double res = 0;
        if( coefs[ci] > 0 )
            CV_Assert( covDeterms[ci] > std::numeric_limits<double>::epsilon() );
            Vec3d diff = color;
            double* m = mean + 3*ci;
            diff[0] -= m[0]; diff[1] -= m[1]; diff[2] -= m[2];
            double mult = diff[0]*(diff[0]*inverseCovs[ci][0][0] + diff[1]*inverseCovs[ci][1][0] + diff[2]*inverseCovs[ci][2][0])
                       + diff[1]*(diff[0]*inverseCovs[ci][0][1] + diff[1]*inverseCovs[ci][1][1] + diff[2]*inverseCovs[ci][2][1])
                       + diff[2]*(diff[0]*inverseCovs[ci][0][2] + diff[1]*inverseCovs[ci][1][2] + diff[2]*inverseCovs[ci][2][2]);
            res = 1.0f/sqrt(covDeterms[ci]) * exp(-0.5f*mult);
        return res;
    int GMM::whichComponent( const Vec3d color ) const
        int k = 0;
        double max = 0;
        for( int ci = 0; ci < componentsCount; ci++ )
            double p = (*this)( ci, color );
            if( p > max )
                k = ci;  //找到概率最大的那个,或者说计算结果最大的那个
                max = p;
        return k;
    void GMM::initLearning()
        for( int ci = 0; ci < componentsCount; ci++)
            sums[ci][0] = sums[ci][1] = sums[ci][2] = 0;
            prods[ci][0][0] = prods[ci][0][1] = prods[ci][0][2] = 0;
            prods[ci][1][0] = prods[ci][1][1] = prods[ci][1][2] = 0;
            prods[ci][2][0] = prods[ci][2][1] = prods[ci][2][2] = 0;
            sampleCounts[ci] = 0;
        totalSampleCount = 0;
    void GMM::addSample( int ci, const Vec3d color )
        sums[ci][0] += color[0]; sums[ci][1] += color[1]; sums[ci][2] += color[2];
        prods[ci][0][0] += color[0]*color[0]; prods[ci][0][1] += color[0]*color[1]; prods[ci][0][2] += color[0]*color[2];
        prods[ci][1][0] += color[1]*color[0]; prods[ci][1][1] += color[1]*color[1]; prods[ci][1][2] += color[1]*color[2];
        prods[ci][2][0] += color[2]*color[0]; prods[ci][2][1] += color[2]*color[1]; prods[ci][2][2] += color[2]*color[2];
    //这里相当于论文中“Iterative minimisation”的step 2
    void GMM::endLearning()
        const double variance = 0.01;
        for( int ci = 0; ci < componentsCount; ci++ )
            int n = sampleCounts[ci]; //第ci个高斯模型的样本像素个数
            if( n == 0 )
                coefs[ci] = 0;
    			coefs[ci] = (double)n/totalSampleCount; 
    			double* m = mean + 3*ci;
                m[0] = sums[ci][0]/n; m[1] = sums[ci][1]/n; m[2] = sums[ci][2]/n;
    			double* c = cov + 9*ci;
                c[0] = prods[ci][0][0]/n - m[0]*m[0]; c[1] = prods[ci][0][1]/n - m[0]*m[1]; c[2] = prods[ci][0][2]/n - m[0]*m[2];
                c[3] = prods[ci][1][0]/n - m[1]*m[0]; c[4] = prods[ci][1][1]/n - m[1]*m[1]; c[5] = prods[ci][1][2]/n - m[1]*m[2];
                c[6] = prods[ci][2][0]/n - m[2]*m[0]; c[7] = prods[ci][2][1]/n - m[2]*m[1]; c[8] = prods[ci][2][2]/n - m[2]*m[2];
    			double dtrm = c[0]*(c[4]*c[8]-c[5]*c[7]) - c[1]*(c[3]*c[8]-c[5]*c[6]) + c[2]*(c[3]*c[7]-c[4]*c[6]);
                if( dtrm <= std::numeric_limits<double>::epsilon() )
    				// Adds the white noise to avoid singular covariance matrix.
                    c[0] += variance;
                    c[4] += variance;
                    c[8] += variance;
    void GMM::calcInverseCovAndDeterm( int ci )
        if( coefs[ci] > 0 )
            double *c = cov + 9*ci;
            double dtrm =
                  covDeterms[ci] = c[0]*(c[4]*c[8]-c[5]*c[7]) - c[1]*(c[3]*c[8]-c[5]*c[6]) 
    								+ c[2]*(c[3]*c[7]-c[4]*c[6]);
            //在C++中,每一种内置的数据类型都拥有不同的属性, 使用<limits>库可以获
    		//b=3时 10*a/b == 20/b不成立。那怎么办呢?
    		CV_Assert( dtrm > std::numeric_limits<double>::epsilon() );
            inverseCovs[ci][0][0] =  (c[4]*c[8] - c[5]*c[7]) / dtrm;
            inverseCovs[ci][1][0] = -(c[3]*c[8] - c[5]*c[6]) / dtrm;
            inverseCovs[ci][2][0] =  (c[3]*c[7] - c[4]*c[6]) / dtrm;
            inverseCovs[ci][0][1] = -(c[1]*c[8] - c[2]*c[7]) / dtrm;
            inverseCovs[ci][1][1] =  (c[0]*c[8] - c[2]*c[6]) / dtrm;
            inverseCovs[ci][2][1] = -(c[0]*c[7] - c[1]*c[6]) / dtrm;
            inverseCovs[ci][0][2] =  (c[1]*c[5] - c[2]*c[4]) / dtrm;
            inverseCovs[ci][1][2] = -(c[0]*c[5] - c[2]*c[3]) / dtrm;
            inverseCovs[ci][2][2] =  (c[0]*c[4] - c[1]*c[3]) / dtrm;
    /*
      Calculate beta - parameter of GrabCut algorithm.
      beta = 1/(2*avg(sqr(||color[i] - color[j]||)))
    */
    static double calcBeta( const Mat& img )
        double beta = 0;
        for( int y = 0; y < img.rows; y++ )
            for( int x = 0; x < img.cols; x++ )
                Vec3d color = img.at<Vec3b>(y,x);
                if( x>0 ) // left  >0的判断是为了避免在图像边界的时候还计算,导致越界
                    Vec3d diff = color - (Vec3d)img.at<Vec3b>(y,x-1);
                    beta += diff.dot(diff);  //矩阵的点乘,也就是各个元素平方的和
                if( y>0 && x>0 ) // upleft
                    Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x-1);
                    beta += diff.dot(diff);
                if( y>0 ) // up
                    Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x);
                    beta += diff.dot(diff);
                if( y>0 && x<img.cols-1) // upright
                    Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x+1);
                    beta += diff.dot(diff);
        if( beta <= std::numeric_limits<double>::epsilon() )
            beta = 0;
            beta = 1.f / (2 * beta/(4*img.cols*img.rows - 3*img.cols - 3*img.rows + 2) ); //论文公式(5)
        return beta;
    /*
      Calculate weights of non-terminal vertices of graph.
      beta and gamma - parameters of GrabCut algorithm.
    */
    static void calcNWeights( const Mat& img, Mat& leftW, Mat& upleftW, Mat& upW, 
    							Mat& uprightW, double beta, double gamma )
        //gammaDivSqrt2相当于公式(4)中的gamma * dis(i,j)^(-1),那么可以知道,
    	const double gammaDivSqrt2 = gamma / std::sqrt(2.0f);
        leftW.create( img.rows, img.cols, CV_64FC1 );
        upleftW.create( img.rows, img.cols, CV_64FC1 );
        upW.create( img.rows, img.cols, CV_64FC1 );
        uprightW.create( img.rows, img.cols, CV_64FC1 );
        for( int y = 0; y < img.rows; y++ )
            for( int x = 0; x < img.cols; x++ )
                Vec3d color = img.at<Vec3b>(y,x);
                if( x-1>=0 ) // left  //避免图的边界
                    Vec3d diff = color - (Vec3d)img.at<Vec3b>(y,x-1);
                    leftW.at<double>(y,x) = gamma * exp(-beta*diff.dot(diff));
                    leftW.at<double>(y,x) = 0;
                if( x-1>=0 && y-1>=0 ) // upleft
                    Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x-1);
                    upleftW.at<double>(y,x) = gammaDivSqrt2 * exp(-beta*diff.dot(diff));
                    upleftW.at<double>(y,x) = 0;
                if( y-1>=0 ) // up
                    Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x);
                    upW.at<double>(y,x) = gamma * exp(-beta*diff.dot(diff));
                    upW.at<double>(y,x) = 0;
                if( x+1<img.cols && y-1>=0 ) // upright
                    Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x+1);
                    uprightW.at<double>(y,x) = gammaDivSqrt2 * exp(-beta*diff.dot(diff));
                    uprightW.at<double>(y,x) = 0;
    //每个像素只能取GC_BGD or GC_FGD or GC_PR_BGD or GC_PR_FGD 四种枚举值,分别表示该像素
    //ICCV2001“Interactive Graph Cuts for Optimal Boundary & Region Segmentation of Objects in N-D Images”
    //Yuri Y. Boykov Marie-Pierre Jolly 
      Check size, type and element values of mask matrix.
    static void checkMask( const Mat& img, const Mat& mask )
        if( mask.empty() )
            CV_Error( CV_StsBadArg, "mask is empty" );
        if( mask.type() != CV_8UC1 )
            CV_Error( CV_StsBadArg, "mask must have CV_8UC1 type" );
        if( mask.cols != img.cols || mask.rows != img.rows )
            CV_Error( CV_StsBadArg, "mask must have as many rows and cols as img" );
        for( int y = 0; y < mask.rows; y++ )
            for( int x = 0; x < mask.cols; x++ )
                uchar val = mask.at<uchar>(y,x);
                if( val!=GC_BGD && val!=GC_FGD && val!=GC_PR_BGD && val!=GC_PR_FGD )
                    CV_Error( CV_StsBadArg, "mask element value must be equel"
                        "GC_BGD or GC_FGD or GC_PR_BGD or GC_PR_FGD" );
    //rect内的设置为 GC_PR_FGD(可能为前景)
      Initialize mask using rectangular.
    static void initMaskWithRect( Mat& mask, Size imgSize, Rect rect )
        mask.create( imgSize, CV_8UC1 );
        mask.setTo( GC_BGD );
        rect.x = max(0, rect.x);
        rect.y = max(0, rect.y);
        rect.width = min(rect.width, imgSize.width-rect.x);
        rect.height = min(rect.height, imgSize.height-rect.y);
        (mask(rect)).setTo( Scalar(GC_PR_FGD) );
    /*
      Initialize GMM background and foreground models using kmeans algorithm.
    */
    static void initGMMs( const Mat& img, const Mat& mask, GMM& bgdGMM, GMM& fgdGMM )
        const int kMeansItCount = 10;  //迭代次数
        const int kMeansType = KMEANS_PP_CENTERS; //Use kmeans++ center initialization by Arthur and Vassilvitskii
        Mat bgdLabels, fgdLabels; //记录背景和前景的像素样本集中每个像素对应GMM的哪个高斯模型,论文中的kn
        vector<Vec3f> bgdSamples, fgdSamples; //背景和前景的像素样本集
        Point p;
        for( p.y = 0; p.y < img.rows; p.y++ )
            for( p.x = 0; p.x < img.cols; p.x++ )
    			if( mask.at<uchar>(p) == GC_BGD || mask.at<uchar>(p) == GC_PR_BGD )
                    bgdSamples.push_back( (Vec3f)img.at<Vec3b>(p) );
                else // GC_FGD | GC_PR_FGD
                    fgdSamples.push_back( (Vec3f)img.at<Vec3b>(p) );
        CV_Assert( !bgdSamples.empty() && !fgdSamples.empty() );
        Mat _bgdSamples( (int)bgdSamples.size(), 3, CV_32FC1, &bgdSamples[0][0] );
        kmeans( _bgdSamples, GMM::componentsCount, bgdLabels,
                TermCriteria( CV_TERMCRIT_ITER, kMeansItCount, 0.0), 0, kMeansType );
        Mat _fgdSamples( (int)fgdSamples.size(), 3, CV_32FC1, &fgdSamples[0][0] );
        kmeans( _fgdSamples, GMM::componentsCount, fgdLabels,
                TermCriteria( CV_TERMCRIT_ITER, kMeansItCount, 0.0), 0, kMeansType );
        for( int i = 0; i < (int)bgdSamples.size(); i++ )
            bgdGMM.addSample( bgdLabels.at<int>(i,0), bgdSamples[i] );
        for( int i = 0; i < (int)fgdSamples.size(); i++ )
            fgdGMM.addSample( fgdLabels.at<int>(i,0), fgdSamples[i] );
    //论文中:迭代最小化算法step 1:为每个像素分配GMM中所属的高斯模型,kn保存在Mat compIdxs中
      Assign GMMs components for each pixel.
    static void assignGMMsComponents( const Mat& img, const Mat& mask, const GMM& bgdGMM, 
    									const GMM& fgdGMM, Mat& compIdxs )
        Point p;
        for( p.y = 0; p.y < img.rows; p.y++ )
            for( p.x = 0; p.x < img.cols; p.x++ )
                Vec3d color = img.at<Vec3b>(p);
                compIdxs.at<int>(p) = mask.at<uchar>(p) == GC_BGD || mask.at<uchar>(p) == GC_PR_BGD ?
                    bgdGMM.whichComponent(color) : fgdGMM.whichComponent(color);
    //论文中:迭代最小化算法step 2:从每个高斯模型的像素样本集中学习每个高斯模型的参数
      Learn GMMs parameters.
    static void learnGMMs( const Mat& img, const Mat& mask, const Mat& compIdxs, GMM& bgdGMM, GMM& fgdGMM )
        Point p;
        for( int ci = 0; ci < GMM::componentsCount; ci++ )
            for( p.y = 0; p.y < img.rows; p.y++ )
                for( p.x = 0; p.x < img.cols; p.x++ )
                    if( compIdxs.at<int>(p) == ci )
                        if( mask.at<uchar>(p) == GC_BGD || mask.at<uchar>(p) == GC_PR_BGD )
                            bgdGMM.addSample( ci, img.at<Vec3b>(p) );
                            fgdGMM.addSample( ci, img.at<Vec3b>(p) );
    /*
      Construct GCGraph
    */
    static void constructGCGraph( const Mat& img, const Mat& mask, const GMM& bgdGMM, const GMM& fgdGMM, double lambda,
                           const Mat& leftW, const Mat& upleftW, const Mat& upW, const Mat& uprightW,
                           GCGraph<double>& graph )
        int vtxCount = img.cols*img.rows;  //顶点数,每一个像素是一个顶点
        int edgeCount = 2*(4*vtxCount - 3*(img.cols + img.rows) + 2);  //边数,需要考虑图边界的边的缺失
    	graph.create(vtxCount, edgeCount);
        Point p;
        for( p.y = 0; p.y < img.rows; p.y++ )
            for( p.x = 0; p.x < img.cols; p.x++)
                // add node
                int vtxIdx = graph.addVtx();  //返回这个顶点在图中的索引
                Vec3b color = img.at<Vec3b>(p);
                // set t-weights			
    			double fromSource, toSink;
                if( mask.at<uchar>(p) == GC_PR_BGD || mask.at<uchar>(p) == GC_PR_FGD )
    				fromSource = -log( bgdGMM(color) );
                    toSink = -log( fgdGMM(color) );
                else if( mask.at<uchar>(p) == GC_BGD )
    				fromSource = 0;
                    toSink = lambda;
                else // GC_FGD
                    fromSource = lambda;
                    toSink = 0;
                graph.addTermWeights( vtxIdx, fromSource, toSink );
                // set n-weights  n-links
    			if( p.x>0 )
                    double w = leftW.at<double>(p);
                    graph.addEdges( vtxIdx, vtxIdx-1, w, w );
                if( p.x>0 && p.y>0 )
                    double w = upleftW.at<double>(p);
                    graph.addEdges( vtxIdx, vtxIdx-img.cols-1, w, w );
                if( p.y>0 )
                    double w = upW.at<double>(p);
                    graph.addEdges( vtxIdx, vtxIdx-img.cols, w, w );
                if( p.x<img.cols-1 && p.y>0 )
                    double w = uprightW.at<double>(p);
                    graph.addEdges( vtxIdx, vtxIdx-img.cols+1, w, w );
    //论文中:迭代最小化算法step 3:分割估计:最小割或者最大流算法
      Estimate segmentation using MaxFlow algorithm
    static void estimateSegmentation( GCGraph<double>& graph, Mat& mask )
        Point p;
        for( p.y = 0; p.y < mask.rows; p.y++ )
            for( p.x = 0; p.x < mask.cols; p.x++ )
    			if( mask.at<uchar>(p) == GC_PR_BGD || mask.at<uchar>(p) == GC_PR_FGD )
                    if( graph.inSourceSegment( p.y*mask.cols+p.x /*vertex index*/ ) )
                        mask.at<uchar>(p) = GC_PR_FGD;
                        mask.at<uchar>(p) = GC_PR_BGD;
    /*
     GrabCut foreground extraction.
       _img                  non-empty CV_8UC3 image.
       _mask                 in/out CV_8UC1 label mask (GC_BGD, GC_FGD, GC_PR_BGD, GC_PR_FGD);
                             created from rect for GC_INIT_WITH_RECT, validated for
                             GC_INIT_WITH_MASK and GC_EVAL.
       rect                  foreground rectangle, used only with GC_INIT_WITH_RECT.
       _bgdModel, _fgdModel  GMM parameter matrices (1 x 13*5 CV_64FC1). They are
                             allocated when empty, re-initialised by the GC_INIT_* modes
                             and otherwise used as given.
       iterCount             number of iterations; if <= 0, only the initialisation runs.
       mode                  GC_INIT_WITH_RECT, GC_INIT_WITH_MASK or GC_EVAL.
     Raises CV_StsBadArg for an empty or non-CV_8UC3 image and for an invalid mask.
    */
    void cv::grabCut( InputArray _img, InputOutputArray _mask, Rect rect,
                      InputOutputArray _bgdModel, InputOutputArray _fgdModel,
                      int iterCount, int mode )
    {
        Mat img = _img.getMat();
        Mat& mask = _mask.getMatRef();
        Mat& bgdModel = _bgdModel.getMatRef();
        Mat& fgdModel = _fgdModel.getMatRef();

        if( img.empty() )
            CV_Error( CV_StsBadArg, "image is empty" );
        if( img.type() != CV_8UC3 )
            CV_Error( CV_StsBadArg, "image must have CV_8UC3 type" );

        GMM bgdGMM( bgdModel ), fgdGMM( fgdModel );
        Mat compIdxs( img.size(), CV_32SC1 );  // per-pixel GMM component index (k_n)

        if( mode == GC_INIT_WITH_RECT || mode == GC_INIT_WITH_MASK )
        {
            if( mode == GC_INIT_WITH_RECT )
                initMaskWithRect( mask, img.size(), rect );
            else // mode == GC_INIT_WITH_MASK
                checkMask( img, mask );
            initGMMs( img, mask, bgdGMM, fgdGMM );
        }

        if( iterCount <= 0 )
            return;  // initialisation only

        if( mode == GC_EVAL )
            checkMask( img, mask );

        const double gamma = 50;
        const double lambda = 9*gamma;
        const double beta = calcBeta( img );

        // The n-link weights depend only on the image, so compute them once.
        Mat leftW, upleftW, upW, uprightW;
        calcNWeights( img, leftW, upleftW, upW, uprightW, beta, gamma );

        for( int i = 0; i < iterCount; i++ )
        {
            GCGraph<double> graph;
            assignGMMsComponents( img, mask, bgdGMM, fgdGMM, compIdxs );  // step 1
            learnGMMs( img, mask, compIdxs, bgdGMM, fgdGMM );             // step 2
            constructGCGraph( img, mask, bgdGMM, fgdGMM, lambda, leftW, upleftW, upW, uprightW, graph );
            estimateSegmentation( graph, mask );                          // step 3
        }
    }



  • 相关阅读:
    LeetCode15 3Sum
    LeetCode10 Regular Expression Matching
    LeetCode20 Valid Parentheses
    LeetCode21 Merge Two Sorted Lists
    LeetCode13 Roman to Integer
    LeetCode12 Integer to Roman
    LeetCode11 Container With Most Water
    LeetCode19 Remove Nth Node From End of List
    LeetCode14 Longest Common Prefix
    LeetCode9 Palindrome Number
  • 原文地址:https://www.cnblogs.com/wangyaning/p/4237008.html
Copyright © 2011-2022 走看看