zoukankan      html  css  js  c++  java
  • 【小波变换】STL版 一维离散小波变换(DWT)库,完全按matlab的wavelet toolbox 的API实现的

    【转】一维离散小波变换(DWT)库,完全按matlab的wavelet toolbox 的API实现的

    来源:http://hi.baidu.com/anatacollin/item/69fdab74ca7d045c0d0a07b4

    一维离散小波变换(DWT)库,完全按 MATLAB Wavelet Toolbox 的 API 实现(2008-12-01 20:37)。最近项目中需要用,就自己写了个,发在这里算是备忘。需要的朋友也可以拿去试试,经测试没有发现 bug,基于 STL 实现。注意:目前只实现了 sym4 小波和 'sym'(对称)边界延拓模式。如果发现 bug 或有什么建议请通知我,谢谢。

    /************************************************************************/
    /* wavelet.h
    * Author: Collin
    * Date: 2008/12/01
    */
    /************************************************************************/
    #ifndef WAVELET_H
    #define WAVELET_H
    #include <vector>

    namespace Wavelet{
        using std::vector;

        // Output of WaveDec, laid out like MATLAB's [C, L] = wavedec(...):
        //   C = [cA_N, cD_N, cD_(N-1), ..., cD_1]   (all bands concatenated)
        //   L = [len(cA_N), len(cD_N), ..., len(cD_1), len(signal)]
        struct C_L
        {
            vector<double> C;
            vector<int> L;
        };
        // A low-pass / high-pass filter pair (decomposition or reconstruction).
        struct WaveFilter{
            vector<double> Low;
            vector<double> High;
        };
        // Result of a single-level DWT.
        struct WaveCoeff{
            vector<double> app;   // approximation coefficients
            vector<double> det;   // detail coefficients
        };

        // sym4 filter bank (the values of MATLAB wfilters('sym4')) at full double
        // precision. Full precision matters: rounding to 4 decimals breaks
        // orthonormality (sum of squares ~0.99994), so a decomposition followed by
        // a reconstruction no longer reproduces the input.
        // Lo_R / Hi_R are Lo_D / Hi_D reversed; Hi_D[k] = (-1)^(k+1) * Lo_D[7-k].
        const double sym4_Lo_D[] = {-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161,
                                    0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427};
        const double sym4_Hi_D[] = {-0.0322231006040427, -0.012603967262037833, 0.09921954357684722, 0.29785779560527736,
                                    -0.8037387518059161, 0.49761866763201545, 0.02963552764599851, -0.07576571478927333};
        const double sym4_Lo_R[] = {0.0322231006040427, -0.012603967262037833, -0.09921954357684722, 0.29785779560527736,
                                    0.8037387518059161, 0.49761866763201545, -0.02963552764599851, -0.07576571478927333};
        const double sym4_Hi_R[] = {-0.07576571478927333, 0.02963552764599851, 0.49761866763201545, -0.8037387518059161,
                                    0.29785779560527736, 0.09921954357684722, -0.012603967262037833, -0.0322231006040427};

        const static WaveFilter sym4_d = {vector<double>(sym4_Lo_D, sym4_Lo_D + 8), vector<double>(sym4_Hi_D, sym4_Hi_D + 8)};
        const static WaveFilter sym4_r = {vector<double>(sym4_Lo_R, sym4_Lo_R + 8), vector<double>(sym4_Hi_R, sym4_Hi_R + 8)};

        // Filter pair for a wavelet: d_or_r = 'd' (decomposition) or 'r'
        // (reconstruction), case-insensitive. Only "sym4" is implemented; any
        // other name or selector prints to stderr and calls exit(1).
        const WaveFilter& WFilters(
            const char* strWaveName,
            const char d_or_r
            );
        // Multilevel 1-D decomposition of `signal` to nMaxLevel levels
        // (cf. MATLAB wavedec). See C_L for the layout of the result.
        C_L WaveDec(
            const vector<double>& signal,
            const int nMaxLevel,
            const char* strWaveName
            );
        // Single-level DWT with explicit decomposition filters: symmetric
        // extension, "valid" convolution, then keep the odd-indexed samples.
        WaveCoeff DWT(
            const vector<double>& signal,
            const vector<double>& Lo_D,
            const vector<double>& Hi_D
            );
        // Reconstructs one band back to the signal length L.back():
        // a_or_d = 'a' (approximation, nLevel 0..N) or 'd' (detail, nLevel 1..N).
        vector<double> WRCoef(
            const char a_or_d,
            const vector<double>& C,
            const vector<int>& L,
            const char* strWaveName,
            const int nLevel
            );
        // Approximation coefficients at nLevel (0..N), obtained by inverse
        // transforming from the coarsest level N down to nLevel.
        vector<double> AppCoef(
            const vector<double>& C,
            const vector<int>& L,
            const char* strWaveName,
            const int nLevel
            );
        // Detail coefficients at nLevel (1..N), sliced out of C using L.
        vector<double> DetCoef(
            const vector<double>& C,
            const vector<int>& L,
            const int nLevel
            );
        //upsample and convolution: insert zeros between samples, convolve with
        //`filter`, keep the central nLen samples. strMode is currently not used.
        vector<double> UpsConv1(
            const vector<double>& signal,
            const vector<double>& filter,
            const int nLen,
            const char* strMode = "sym"
            );
        // Full linear convolution (length size(a) + size(b) - 1).
        vector<double> Conv(
            const vector<double>& vecSignal,
            const vector<double>& vecFilter
            );
        // Single-level inverse DWT: UpsConv1(app, Lo_R) + UpsConv1(det, Hi_R),
        // both cropped to nLenCentral samples.
        vector<double> IDWT(
            const vector<double>& app,
            const vector<double>& det,
            const vector<double>& Lo_R,
            const vector<double>& Hi_R,
            const int nLenCentral
            );
        // Pads `signal` with nLenExt samples on each side by half-point
        // symmetric reflection. `mode` is currently not used (always "sym").
        vector<double> WExtend(
            const vector<double>& signal,
            const int nLenExt,
            const char* mode = "sym"
            );
        // Convolution of signal with filter, cropped according to `shape`
        // (default "valid": only samples where the filter fully overlaps).
        vector<double> WConv1(
            const vector<double>& signal,
            const vector<double>& filter,
            const char* shape = "valid"
            );
    }
    #endif



    /************************************************************************/
    /* wavelet.cpp
    * Author: Collin
    * Date: 2008/12/01
    */
    /************************************************************************/
    #include "wavelet.h"

    #include <cctype>
    #include <cstdlib>
    #include <cstring>
    #include <iostream>
    #include <string>
    #include <vector>

    using namespace std;
    using namespace Wavelet;
    C_L Wavelet::WaveDec(const vector<double>& signal,
                        const int nMaxLevel,
                        const char* strWaveName
                        )
    {
        const WaveFilter& filters = WFilters(strWaveName, 'd');
        int len = signal.size();
        C_L cl;
        cl.L.push_back(len);
        WaveCoeff waveCoeff;
        waveCoeff.app = signal;
        vector<double>::iterator itC;
        vector<int>::iterator itL;
        for (int i = 0; i < nMaxLevel; ++i){
            waveCoeff = DWT(waveCoeff.app, filters.Low, filters.High);
            itC = cl.C.begin();
            cl.C.insert(itC, waveCoeff.det.begin(), waveCoeff.det.end());
            itL = cl.L.begin();
            cl.L.insert(itL, waveCoeff.det.size());
        }
        itC = cl.C.begin();
        cl.C.insert(itC, waveCoeff.app.begin(), waveCoeff.app.end());
        itL = cl.L.begin();
        cl.L.insert(itL, waveCoeff.app.size());
        return cl;
    }

    vector<double> Wavelet::WRCoef(const char a_or_d,
                                   const vector<double>& C,
                                   const vector<int>& L,
                                   const char* strWaveName,
                                   const int nLevel
                                   )
    {
        vector<double> Coef;
        const WaveFilter& filter = WFilters(strWaveName, 'r');
        int nMax = L.size() - 2;
        int nMin;
        char type = tolower(a_or_d);
        if ('a' == type)
            nMin = 0;
        else if ('d' == type)
            nMin = 1;
        else {
            cerr << "bad parameter: a_or_d: "<< a_or_d << "\n";
            exit(1);
        }
        if (nLevel < nMin || nLevel > nMax){
            cerr << "bad parameter for level\n";
            exit(1);
        }
        vector<double> F1;
        switch (type){
            case 'a':
                Coef = AppCoef(C, L, strWaveName, nLevel);
                if (0 == nLevel)
                    return Coef;
                F1 = filter.Low;
                break;
            case 'd':
                Coef = DetCoef(C, L, nLevel);
                F1 = filter.High;
                break;
            default:
                ;
        }
        int iMin = L.size() - nLevel;
        Coef = UpsConv1(Coef, F1, L[iMin], "sym");
        for (int k = 1; k < nLevel; ++k){
            Coef = UpsConv1(Coef, filter.Low, L[iMin + k], "sym");
        }
        return Coef;
    }
    vector<double> Wavelet::UpsConv1(const vector<double>& signal,
                                    const vector<double>& filter,
                                    const int nLen,
                                    const char* strMode
                                    )
    {
        //implement dyadup(y,0)
        vector<double> y(2 * signal.size() - 1);
        y[0] = signal[0];
        for (int i = 1; i < signal.size(); ++i){
            y[2*i - 1] = 0;
            y[2*i] = signal[i];
        }
        y = Conv(y, filter);

        //extract the central portion
        vector<double>::iterator it = y.begin();
        return vector<double>(it + (y.size() - nLen) / 2, it + (y.size() + nLen) / 2);
    }
    vector<double> Wavelet::Conv(const vector<double>& vecSignal, const vector<double>& vecFilter){
        vector<double> signal(vecSignal);
        vector<double> filter(vecFilter);
        if (signal.size() < filter.size())
            signal.swap(filter);   
        int lenSignal = signal.size();
        int lenFilter = filter.size();
        vector<double> result(lenSignal + lenFilter - 1);
        for (int i = 0; i < lenFilter; i++){
            for (int j = 0; j <= i; j++)
                result[i] += signal[j] * filter[i - j];
        }
        for (int i = lenFilter; i < lenSignal; i++){
            for (int j = 0; j <lenFilter; j++)
                result[i] += signal[i - j] * filter[j];
        }
        for (int i = lenSignal; i < lenSignal + lenFilter - 1; i++){
            for (int j = i - lenSignal + 1; j < lenFilter; j++)
                result[i] += signal[i - j] * filter[j];
        }
        return result;   
    }
    vector<double> Wavelet::DetCoef(const vector<double>& C,
                           const vector<int>& L,
                           const int nLevel
                           )
    {
        if (nLevel < 1 || nLevel > L.size() - 2){
            cerr << "bad level parameter\n";
            exit(1);
        }

        int nlast = 0, nfirst = 0;
        vector<int>::const_reverse_iterator it = L.rbegin();
        ++it;
        for (int i = 1; i < nLevel; ++i){
            nlast += *it;
            ++it;
        }
        nfirst = nlast + *it;
        return vector<double>(C.end() - nfirst, C.end() - nlast);
    }
    WaveCoeff Wavelet::DWT(const vector<double>& signal,
                const vector<double>& Lo_D,
                const vector<double>& Hi_D
                )
    {
        // Single-level DWT (cf. MATLAB dwt with 'sym' extension):
        // pad by len(Lo_D) - 1 samples on each side, take the "valid"
        // convolution with each filter, and keep the odd-indexed samples
        // (dyadic downsampling).
        const vector<double> padded = WExtend(signal, static_cast<int>(Lo_D.size()) - 1, "sym");

        WaveCoeff coeff;
        const vector<double> lowBand = WConv1(padded, Lo_D, "valid");
        for (size_t pos = 1; pos < lowBand.size(); pos += 2)
            coeff.app.push_back(lowBand[pos]);

        const vector<double> highBand = WConv1(padded, Hi_D, "valid");
        for (size_t pos = 1; pos < highBand.size(); pos += 2)
            coeff.det.push_back(highBand[pos]);

        return coeff;
    }
    const WaveFilter& Wavelet::WFilters(const char* strWaveName,
                               const char d_or_r
                               )
    {
        char type = tolower(d_or_r);
        if (!strcmp(strWaveName, "sym4")){
            switch(type){
            case 'd':
                return Wavelet::sym4_d;
                break;
            case 'r':
                return Wavelet::sym4_r;
                break;
            default:
                cerr << "bad parameter for d_or_r\n";
                exit(1);
            }
        }
        else {
            cerr << "not implement \n";
            exit(1);
        }
    }

    vector<double> Wavelet::AppCoef(
                                    const vector<double>& C,
                                    const vector<int>& L,
                                    const char* strWaveName,
                                    const int nLevel
                                    )
    {
        int nMaxLevel = L.size() - 2;
        if (nLevel < 0 || nLevel > nMaxLevel){
            cerr << "bad parameter for level\n";
            exit(1);
        }
        const WaveFilter& filters = WFilters(strWaveName, 'r');
        vector<double> app(C.begin(), C.begin() + L[0]); //app for the last level
        vector<double> det;
        for (int i = 0; i < nMaxLevel - nLevel; ++i){
            det = DetCoef(C, L, nMaxLevel - i);
            app = IDWT(app, det, filters.Low, filters.High, L[i + 2]);
        }
        return app;
    }
    vector<double> Wavelet::IDWT(
                        const vector<double>& app,
                        const vector<double>& det,
                        const vector<double>& Lo_R,
                        const vector<double>& Hi_R,
                        const int nLenCentral
                        )
    {
        vector<double> app1, app2;
        app1 = UpsConv1(app, Lo_R, nLenCentral, "sym");
        app2 = UpsConv1(det, Hi_R, nLenCentral, "sym");
        for (int i = 0; i < nLenCentral; ++i){
            app1[i] += app2[i];
        }
        return app1;
    }
    vector<double> Wavelet::WExtend(
                           const vector<double>& signal,
                           const int nLenExt,
                           const char* mode
                           )
    {
        int signalLen = signal.size();
        vector<double> result(signalLen + 2 * nLenExt);
        for (int i = 0, idx = nLenExt; idx < signalLen + nLenExt; ++i, ++idx){
            result[idx] = signal[i];
        }
        for (int idx = nLenExt - 1, bFlag = 1, signalIdx = 0; idx >= 0; --idx){
            result[idx] = signal[signalIdx];
            if (bFlag && ++signalIdx == signalLen){
                bFlag = 0;
                signalIdx = signalLen - 1;
            }
            else if (!bFlag && --signalIdx == -1) {
                bFlag = 1;
                signalIdx = 0;
            }
        }
        for (int idx = nLenExt + signalLen, bFlag = 0, signalIdx = signalLen - 1; idx < 2 * nLenExt + signalLen; ++idx){
            result[idx] = signal[signalIdx];
            if (bFlag && ++signalIdx == signalLen){
                bFlag = 0;
                signalIdx = signalLen - 1;
            }
            else if (!bFlag && --signalIdx == -1) {
                bFlag = 1;
                signalIdx = 0;
            }
        }
        return result;
    }
    vector<double> Wavelet::WConv1(
                        const vector<double>& signal,
                        const vector<double>& filter,
                        const char* shape
                        )
    {
        // Convolution with MATLAB conv/conv2 shape semantics:
        //   "full"  - the whole convolution, length ls + lf - 1
        //   "same"  - central part with the length of `signal`
        //   "valid" - only samples where the filter fully overlaps the signal,
        //             length ls - lf + 1 (empty if the signal is shorter).
        // Any other or null `shape` is treated as "valid" (the default),
        // which is what every shape produced before.
        const int lenSignal = static_cast<int>(signal.size());
        const int lenFilter = static_cast<int>(filter.size());
        if (lenSignal == 0 || lenFilter == 0)
            return vector<double>();

        const vector<double> full = Conv(signal, filter);
        if (shape && 0 == strcmp(shape, "full"))
            return full;
        if (shape && 0 == strcmp(shape, "same")){
            const int first = lenFilter / 2;
            return vector<double>(full.begin() + first, full.begin() + first + lenSignal);
        }
        // "valid": drop lf - 1 samples from each end. Guard the short-signal
        // case, where the crop bounds would otherwise cross.
        if (lenSignal < lenFilter)
            return vector<double>();
        return vector<double>(full.begin() + (lenFilter - 1), full.begin() + lenSignal);
    }

    本博客未标明转载的内容均为本站原创,转载时请署名(richard.hmm)并注明来源(www.cnblogs.com/IDoIUnderstand/)。请勿用于任何商业用途,作者(richard.hmm)保留本博客所有内容的一切权利。
  • 相关阅读:
    SpringIOC的小例子
    java中递归实现复制多级文件夹
    快速排序和几种简单排序
    Oracle面试的基本题
    多态的两个小例子
    单例模式
    内部类与匿名内部类
    C#
    C#
    C#
  • 原文地址:https://www.cnblogs.com/IDoIUnderstand/p/3280723.html
Copyright © 2011-2022 走看看