zoukankan      html  css  js  c++  java
  • 深度学习结合非局部均值滤波的图像去噪算法

    其实这是半年之前完成的内容,一直懒着没有总结,今天看了看代码,发觉再不总结自己以后都看不懂了,故整理如下。

    非局部均值是一种基于块匹配来确定滤波权值的。即先确定一个块的大小,例如7x7,然后在确定一个搜索区域,例如15x15,在15x15这个搜索区域中的每一个点,计算7x7的窗口与当前滤波点7x7窗口的相似性(使用绝对差和SAD,一般而言,窗口中各点的差值还需要乘以经高斯核生成的权重参数,离中心点越近,权重值越大一些),然后根据相似性值使用指数函数生成窗口中心点的权重参数,相似性越高,该中心点的权重越大,最后各中心点的加权平均就是最终滤波图像,能获得很好的视觉效果。

    非局部均值的成功之处主要在于充分利用了块的相似性,而后续步骤由相似性计算对应权重值,按照经验使用指数函数,其参数h有着至关重要的作用,许多论文也是在h上面做改进。如果我们跳出加权平均和指数函数的思路,完全可以将含噪图像所有相邻点的像素值、相似性值、距离等做为输入送给深度学习网络,将原图像值作为输出进行训练啊,训练好的模型就可以直接用于滤波。

    下面附一个简化版的python代码,经实测改进后的算法比原生的非局部均值滤波要好,里面的网络模型过于简单,想提升效果的自己修改调优吧。

    注意使用的是python3环境

    #coding:utf8
    import cv2, datetime,sys,glob
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    
    from keras.models import Sequential, model_from_json
    from keras.layers import Dense, Activation,Dropout,Flatten,Merge
    from keras.callbacks import EarlyStopping
    from keras.layers.convolutional import Convolution2D,Convolution3D
    
    def psnr(A, B):
        return 10*np.log(255*255.0/(((A.astype(np.float)-B)**2).mean()))/np.log(10)
    def double2uint8(I, ratio=1.0):
        return np.clip(np.round(I*ratio), 0, 255).astype(np.uint8)
    
    def GetNlmData(I, templateWindowSize=4,  searchWindowSize=9):
        f = int(templateWindowSize / 2)
        t = int(searchWindowSize / 2)
        height, width = I.shape[:2]
        padLength = t + f
        I2 = np.pad(I, padLength, 'symmetric')
        I_ = I2[padLength - f:padLength + f + height, padLength - f:padLength + f + width]
    
        res = np.zeros((height, width, templateWindowSize+2, t+t+1, t+t+1))
        for i in range(-t, t + 1):
            for j in range(-t, t + 1):
                I2_ = I2[padLength + i - f:padLength + i + f + height, padLength + j - f:padLength + j + f + width]
                for kk in range(templateWindowSize):
                    kernel = np.ones((2*kk+1, 2*kk+1))
                    kernel = kernel/kernel.sum()
                    res[:, :, kk, i+t, j+t] = cv2.filter2D((I2_-I_) ** 2, -1,  kernel)[f:f + height, f:f + width]
                res[:, :, -2, i+t, j+t] = I2_[f:f + height, f:f + width]-I
                res[:, :, -1, i+t, j+t] = np.exp(-np.sqrt(i**2+j**2))
        print(res.max(), res.min())
        return res
    
    def zmTrain(trainX, trainY):
        model = Sequential()
        if 1:
            model.add(Dense(100, init='uniform', input_dim=trainX.shape[1]))
            model.add(Activation('relu'))
            model.add(Dense(50))
            model.add(Activation('relu'))
            model.add(Dense(1))
            model.compile(loss='mse', optimizer='adam', metrics=['accuracy'])
        else:
            with open('model.json', 'rb') as fd:
                model = model_from_json(fd.read())
                model.load_weights('weight.h5')
                model.compile(loss='msle', optimizer='adam', metrics=['accuracy'])
    
        early_stopping = EarlyStopping(monitor='val_loss', patience=5)
        hist =model.fit(trainX, trainY, batch_size=150, epochs=200, shuffle=True, verbose=2, validation_split=0.1
                        ,callbacks=[early_stopping])
        print(hist.history)
    
        res = model.predict(trainX)
        res = np.clip(np.round(res.ravel() * 255), 0, 255)
        print(psnr(res, trainY*255))
        return model
    if __name__ == '__main__':
        sigma = 20.0
        if 1:                         #这部分代码用于训练模型
            trainX = None
            trainY = None
    
            for d in glob.glob('./img/_*'):
                I = cv2.imread(d,0)
                I1 = double2uint8(I + np.random.randn(*I.shape) *sigma)
                data = GetNlmData(I1.astype(np.double)/255)
                s = data.shape
                data.resize((np.prod(s[:2]), np.prod(s[2:])))
                
                if trainX is None:
                    trainX = data
                    trainY = ((I.astype(np.double)-I1)/255).ravel()
                else:
                    trainX = np.concatenate((trainX, data), axis=0)
                    trainY = np.concatenate((trainY,  ((I.astype(np.double)-I1)/255).ravel()), axis=0)
                
            
            model = zmTrain(trainX, trainY)
            with open('model.json', 'wb') as fd:
                #fd.write(model.to_json())
                fd.write(bytes(model.to_json(),'utf8'))
            model.save_weights('weight.h5')
        if 1:                       #滤波
            with open('model.json', 'rb') as fd:
                model = model_from_json(fd.read().decode())
                model.load_weights('weight.h5')
            I = cv2.imread('lena.jpg', 0)
            I1 = double2uint8(I + np.random.randn(*I.shape) * sigma)
    
            data= GetNlmData(I1.astype(np.double)/255)
            s = data.shape
            data.resize((np.prod(s[:2]), np.prod(s[2:])))
            res = model.predict(data)
            res.resize(I.shape)
            res = np.clip(np.round(res*255 +I1), 0, 255)
            print('nwNLM PSNR', psnr(res, I))
            res = res.astype(np.uint8)
            cv2.imwrite('cvOut.bmp', res)
    



  • 相关阅读:
    无法在WEB服务器上启动调试
    Zedgraph悬停时显示内容闪烁的解决
    用ZedGraph控件作图圆
    34.node.js之Url & QueryString模块
    33.Node.js 文件系统fs
    32.Node.js中的常用工具类util
    31.Node.js 常用工具 util
    30.Node.js 全局对象
    28.Node.js 函数和匿名函数
    27.Node.js模块系统
  • 原文地址:https://www.cnblogs.com/zmshy2128/p/7118675.html
Copyright © 2011-2022 走看看