zoukankan      html  css  js  c++  java
  • 受限玻尔兹曼机RBM

    相关算法

    Python 代码参考:http://blog.csdn.net/zc02051126/article/details/9668439(作了少量修改与注释)

      1 #coding:utf8
      2 import matplotlib.pylab as plt
      3 import numpy as np
      4 import cPickle
      5 
      6 
      7 class RBM:
      8     def __init__(self,n_visul, n_hidden, max_epoch = 50, batch_size = 110, penalty = 2e-4):
      9         self.n_visible = n_visul
     10         self.n_hidden = n_hidden
     11         self.max_epoch = max_epoch
     12         self.batch_size = batch_size
     13         self.penalty = penalty
     14         self.w = np.random.random((self.n_visible, self.n_hidden)) * 0.1
     15         self.v_bias = np.zeros((1, self.n_visible))
     16         self.h_bias = np.zeros((1, self.n_hidden))
     17 
     18     def sigmoid(self, z):
     19         return 1.0 / (1.0 + np.exp( -z ))
     20 
     21     def forward(self, vis):
     22         return self.sigmoid(np.dot(vis.T, self.w) + self.h_bias)
     23 
     24     def backward(self, vis):
     25         return self.sigmoid(np.dot(vis, self.w.T) + self.v_bias)
     26 
     27     def batch(self):
     28         d, N = self.x.shape
     29         num_batchs = int(round(N / self.batch_size)) + 1
     30         groups = np.ravel(np.repeat([range(0, num_batchs)], self.batch_size, axis = 0))
     31         groups=groups[:N]
     32         np.random.shuffle(groups)
     33         batch_data = []
     34         for i in range(0, num_batchs):
     35             index = groups == i
     36             batch_data.append(self.x[:, index])
     37         return batch_data
     38 
     39     def rbmBB(self, x):
     40         self.x = x
     41         eta = 0.1
     42         momentum = 0.5  #动量项
     43         W = self.w
     44         b = self.h_bias
     45         c = self.v_bias
     46         Winc  = np.zeros((self.n_visible, self.n_hidden))
     47         binc = np.zeros(self.n_hidden)
     48         cinc = np.zeros(self.n_visible)
     49         batch_data = self.batch()
     50         num_batch = len(batch_data)
     51         errors = []
     52         for epoch in range(0, self.max_epoch):
     53             err_sum = 0.0
     54             for batch in range(0, num_batch):
     55                 num_dims, num_cases = batch_data[batch].shape
     56                 data = batch_data[batch]
     57                 # 已知可见层,采样出隐藏层
     58                 ph = self.forward(data)
     59                 ph_states = np.zeros((num_cases, self.n_hidden))
     60                 ph_states[ph > np.random.random((num_cases, self.n_hidden))] = 1
     61                 # 已知隐藏层,采样出可见层
     62                 neg_data = self.backward(ph_states)
     63                 neg_data_states = np.zeros((num_cases, num_dims))
     64                 neg_data_states[neg_data > np.random.random((num_cases, num_dims))] = 1
     65                 neg_data_states = neg_data_states.transpose()
     66                 nh = self.forward(neg_data_states)
     67                 # CD算法
     68                 dW = np.dot(data, ph) - np.dot(neg_data_states, nh)
     69                 dc = np.sum(data, axis = 1) - np.sum(neg_data_states, axis = 1)
     70                 db = np.sum(ph, axis = 0) - np.sum(nh, axis = 0)
     71                 # 刷新参数
     72                 Winc = momentum * Winc + eta * (dW / num_cases - self.penalty * W)
     73                 binc = momentum * binc + eta * (db / num_cases);
     74                 cinc = momentum * cinc + eta * (dc / num_cases);
     75                 W = W + Winc
     76                 b = b + binc
     77                 c = c + cinc
     78                 self.w = W
     79                 self.h_bais = b
     80                 self.v_bias = c
     81                 err = np.linalg.norm(data - neg_data.transpose())
     82                 err_sum += err
     83             print epoch, err_sum
     84             errors.append(err_sum)
     85         self.errors = errors
     86         self.hiden_value = self.forward(self.x)
     87         h_row, h_col = self.hiden_value.shape
     88         hiden_states = np.zeros((h_row, h_col))
     89         hiden_states[self.hiden_value > np.random.random((h_row, h_col))] = 1
     90         self.rebuild_value = self.backward(hiden_states)
     91 
     92     def visualize(self, X):  #可视化
     93         D, N = X.shape
     94         s = int(np.sqrt(D))
     95         num = int(np.ceil(np.sqrt(N)))
     96         a = np.zeros((num*s + num + 1, num * s + num + 1)) - 1.0
     97         x = 0
     98         y = 0
     99         for i in range(0, N):
    100             z = X[:,i]
    101             z = z.reshape(s,s,order='F')
    102             z = z.transpose()
    103             a[x*s+x:x*s+s+x , y*s+y:y*s+s+y] = z
    104             x = x + 1
    105             if(x >= num):
    106                 x = 0
    107                 y = y + 1
    108         return a
    109 
    110 def readData(path):
    111     data = []
    112     for line in open(path, 'r'):
    113         ele = line.split(' ')
    114         tmp = []
    115         for e in ele:
    116             if e != '':
    117                 tmp.append(float(e.strip(' ')))
    118         data.append(tmp)
    119     return data
    120 
    121 if __name__ == '__main__':
    122     f = open('mnist.pkl', 'rb')
    123     training_data, validation_data, test_data = cPickle.load(f)
    124     training_inputs = [np.reshape(x, 784) for x in training_data[0]]
    125     data =training_inputs[:5000]
    126     data = np.array(data)
    127     data = data.transpose()
    128     rbm = Rbm(784, 100,max_epoch = 50)
    129     rbm.rbmBB(data)
    130 
    131     a = rbm.visualize(data)  #(2060L, 2060L)
    132     fig = plt.figure(1)
    133     ax = fig.add_subplot(111)
    134     ax.imshow(a)
    135     plt.title('original data')
    136 
    137     rebuild_value = rbm.rebuild_value.transpose()
    138     b = rbm.visualize(rebuild_value)  #(2060L, 2060L)
    139     fig = plt.figure(2)
    140     ax = fig.add_subplot(111)
    141     ax.imshow(b)
    142     plt.title('rebuild data')
    143 
    144     hidden_value = rbm.hiden_value.transpose()
    145     c = rbm.visualize(hidden_value)  #(782L, 782L)
    146     fig = plt.figure(3)
    147     ax = fig.add_subplot(111)
    148     ax.imshow(c)
    149     plt.title('hidden data')
    150 
    151     w_value = rbm.w
    152     d = rbm.visualize(w_value)  #(291L, 291L)
    153     fig = plt.figure(4)
    154     ax = fig.add_subplot(111)
    155     ax.imshow(d)
    156     plt.title('weight value(w)')
    157     plt.show()

  • 相关阅读:
    autorun.inf删除方法
    Re_Write序列号
    最常用的正则表达式
    SQL聚合使用GROUP BY
    Ext.Net的Window控件的简单使用
    SQL统计查询一个表中的记录,然后减法运算
    C#金额转换为汉字大写
    Ext.Net的Button按钮的使用
    C# 参考之方法参数关键字:params、ref及out 引用
    C#连接ACCESS 2007数据库
  • 原文地址:https://www.cnblogs.com/qw12/p/5818692.html
Copyright © 2011-2022 走看看