zoukankan      html  css  js  c++  java
  • TensorFlow 实现深度神经网络 —— Denoising Autoencoder

    完整代码请见 models/DenoisingAutoencoder.py at master · tensorflow/models · GitHub

    1. Denoising Autoencoder 类设计与构造函数

    • 简单起见,这里仅考虑一种单隐层的去噪自编码器结构;
      • 即整个网络拓扑结构为:输入层,单隐层,输出层;
        • 输入层 ⇒ 单隐层,可视为编码的过程,需要非线性的激励函数
        • 单隐层 ⇒ 输出层,可视为解码的过程,也可称之为某种意义上的重构(reconstruction),无需激励函数
    class DenoisingAutoencoder():
        def __init__(self, n_input, transfer_fn, ):
            ...
    
            # model
            self.x = tf.placeholder(dtype=tf.float32, shape=[None, self.n_input])
            self.x_corrupted = 
            self.hidden = self.transfer(tf.add(tf.matmul(self.x_corrupted , self.weights['w1']), self.weights['b1']))
            self.reconstruction = tf.add(tf.matmul(self.hidden, self.weights['w2']), self.weights['b2'])
    
            # cost
            self.cost = .5*tf.reduce_mean(tf.pow(tf.subtract(self.reconstruction, self.x), 2))

    2. 实现细节

    • 对于 autoencoder,自编码器属于无监督学习范畴,通过限定或者约束目标输出值等于输入数据,实现对原始输入信号的自动编码,从特征学习的观点来看,学到的编码也可视为一种对原始输入信号的层次化特征表示。

      在代码中,表现为,损失函数的定义上,self.cost = .5*tf.reduce_mean(tf.pow(tf.subtract(self.reconstruction, self.x), 2))

    3. 两种加噪的方式

    去噪自编码器模型的输入是原始输入经某种形式的加噪过程后的退化形式,加噪过程一般分为:

    • 加性高斯噪声(additive gaussian noise)

      self.scale = tf.placeholder(dtype=tf.float32)
      self.x_corrupted = tf.add(self.x, self.scale * tf.random_normal(shape=(self.n_input, )))
    • 掩模噪声(mask)

      self.keep_prob = tf.placeholder(dtype=tf.float32)
      self.x_corrupted = tf.nn.dropout(self.x, self.keep_prob)

    4. 椒盐噪声(salt & pepper)

    def salt_and_pepper_noise(X, v):
        X_noise = X.copy()
        n_features = X.shape[1]
        mn = X.min()
        mx = X.max()
    
        for i, sample in enumerate(X):
            mask = np.random.randint(0, n_features, v)
            for m in mask:
                if np.random.rand() < .5:
                    X_noise[i][m] = mn
                else:
                    X_noise[i][m] = mx
        return X_noise
    

    utilities.py

  • 相关阅读:
    通过Ambari2.2.2部署HDP大数据服务
    Ganglia监控安装配置
    Kafka安装配置
    Graylog2日志服务安装配置
    Dnsmasq域名解析系统安装配置
    在haoodp-2.7.3 HA的基础上安装Hbase HA
    MySQL5.6基于mysql-proxy实现读写分离
    MySQL5.6基于MHA方式高可用搭建
    CentOS使用yum安装drbd
    MySQL5.6基于GTID的主从复制配置
  • 原文地址:https://www.cnblogs.com/mtcnn/p/9421869.html
Copyright © 2011-2022 走看看