  • DBN (Deep Belief Network)

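    A DBN is trained greedily, one layer at a time, with the CD (contrastive divergence) algorithm. The resulting per-layer parameters Wi and ci initialize the DBN, and a supervised learning algorithm then fine-tunes the parameters. This example uses a softmax classifier (covered in the next post) as the supervised algorithm. In the code below, fine-tuning trains only the softmax layer on the output of the top RBM; the RBM weights stay fixed.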
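The layer-wise CD training described above can be sketched in NumPy as follows. This is a minimal, hypothetical illustration of CD-1 (a single Gibbs step) for a binary-binary RBM plus greedy stacking; it is not the author's rbm2.RBM, whose internals are not shown in this post. The class and function names, the learning rate, the full-batch updates and the rows-as-samples layout (the post's own code stores samples as columns) are all assumptions.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


class BernoulliRBM:
    """Hypothetical binary-binary RBM trained with CD-1 (not rbm2.RBM)."""

    def __init__(self, n_visible, n_hidden, seed=0):
        self.rng = np.random.default_rng(seed)
        self.W = 0.01 * self.rng.standard_normal((n_visible, n_hidden))
        self.b = np.zeros(n_visible)  # visible bias
        self.c = np.zeros(n_hidden)   # hidden bias (the ci in the text)

    def hidden_probs(self, v):
        # p(h_j = 1 | v) for every row of v
        return sigmoid(v @ self.W + self.c)

    def visible_probs(self, h):
        # p(v_i = 1 | h) for every row of h
        return sigmoid(h @ self.W.T + self.b)

    def cd1_step(self, v0, lr):
        # Positive phase: hidden probabilities and a binary sample given the data
        ph0 = self.hidden_probs(v0)
        h0 = (self.rng.random(ph0.shape) < ph0).astype(float)
        # Negative phase: one Gibbs step down to the visible layer and back up
        pv1 = self.visible_probs(h0)
        ph1 = self.hidden_probs(pv1)
        # CD approximation of the gradient: <v h>_data - <v h>_reconstruction
        n = v0.shape[0]
        self.W += lr * (v0.T @ ph0 - pv1.T @ ph1) / n
        self.b += lr * (v0 - pv1).mean(axis=0)
        self.c += lr * (ph0 - ph1).mean(axis=0)
        return float(np.mean((v0 - pv1) ** 2))  # reconstruction error

    def train(self, data, epochs=10, lr=0.1):
        err = None
        for _ in range(epochs):
            err = self.cd1_step(data, lr)  # full-batch update for brevity
        return err


def pretrain_stack(data, layer_sizes, epochs=10, lr=0.1):
    """Greedy layer-wise pretraining: each RBM is trained on the hidden
    probabilities produced by the RBM below it."""
    rbms, x = [], data
    for n_hidden in layer_sizes:
        rbm = BernoulliRBM(x.shape[1], n_hidden)
        rbm.train(x, epochs, lr)
        rbms.append(rbm)
        x = rbm.hidden_probs(x)  # becomes the next layer's training input
    return rbms, x


# Usage on random binary data (not MNIST)
rng = np.random.default_rng(1)
toy = (rng.random((100, 16)) < 0.3).astype(float)
rbms, features = pretrain_stack(toy, [8, 4])
print(features.shape)  # (100, 4)
```

The top-layer features returned by pretrain_stack play the role of rbm_output in fineTune: they are what the supervised classifier is trained on.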

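    The RBM is the same as in the previous post. Stacking three RBMs (784→200→100→50 units) shrinks the softmax parameter matrix from (10, 784) to (10, 50). A softmax classifier used on its own reaches a similar (or slightly better) accuracy, but takes somewhat longer to train.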
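The shape bookkeeping behind these numbers can be checked with a few lines of Python: DBN(3, 10, 784, 200) starts with 200 hidden units, and pretrainRBM feeds each layer's hidden size forward as the next input size and halves it. The helper name layer_shapes is hypothetical. Counting only the entries of the theta matrix whose shape the script prints, the softmax layer goes from 7,840 to 500 weights.

```python
def layer_shapes(n_input, first_hidden, n_layers):
    """Reproduce pretrainRBM's schedule: each RBM's input size is the previous
    RBM's hidden size, and its hidden size is half of that."""
    shapes, v, h = [], n_input, first_hidden
    for _ in range(n_layers):
        shapes.append((v, h))
        v, h = h, h // 2
    return shapes


n_classes = 10
shapes = layer_shapes(784, 200, 3)  # matches DBN(3, 10, 784, 200)
top = shapes[-1][1]                 # size of the top RBM's hidden layer

print(shapes)                              # [(784, 200), (200, 100), (100, 50)]
print((n_classes, 784), (n_classes, top))  # (10, 784) (10, 50)
print(n_classes * 784, n_classes * top)    # 7840 500
```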

    from rbm2 import RBM          # RBM class from the previous post
    from softmax import SoftMax   # softmax classifier from the next post
    import os
    import numpy as np
    import cPickle

    class DBN:
        def __init__(self, nlayers, ntype, vlen, hlen):
            self.rbm_layers = []
            self.nlayers = nlayers   # number of stacked RBMs
            self.ntype = ntype       # number of classes
            self.vlen = vlen         # visible units of the first RBM
            self.hlen = hlen         # hidden units of the first RBM

        def calcRBMForward(self, x):
            # Pass x through every RBM layer in turn; the output of the
            # top layer is the feature representation fed to softmax
            for rbm in self.rbm_layers:
                x = rbm.forward(x.T)
            return x

        def load_param(self, dbnpath, softmaxpath):
            # Rebuild the RBM stack from the saved weight matrices.
            # Only the weights are stored in dbn.pkl; any RBM biases
            # are neither saved nor restored.
            with open(dbnpath, 'rb') as f:
                weights = cPickle.load(f)
            self.nlayers = len(weights)
            for i in range(self.nlayers):
                weight = weights[i]
                v, h = np.shape(weight)
                rbm = RBM(v, h)
                rbm.w = weight
                self.rbm_layers.append(rbm)
                print "RBM layer%d shape:%s" % (i, str(rbm.w.shape))
            self.softmax = SoftMax()
            self.softmax.load_theta(softmaxpath)
            print "softmax parameter: " + str(self.softmax.theta.shape)

        def pretrainRBM(self, trainset):
            # Greedy layer-wise pretraining: each RBM is trained with CD
            # on the hidden activations of the layer below it
            weights = []
            for i in range(self.nlayers):
                rbm = RBM(self.vlen, self.hlen)
                if i == 0:
                    traindata = trainset
                else:
                    traindata = np.array(outdata.T)
                rbm.rbmBB(traindata)
                outdata = np.mat(rbm.forward(traindata))
                self.rbm_layers.append(rbm)
                weights.append(rbm.w)
                # The next layer's input size is this layer's hidden size,
                # and its hidden layer is half as wide (784->200->100->50)
                self.vlen = self.hlen
                self.hlen = self.hlen // 2
            if not os.path.exists("data"):
                os.makedirs("data")
            with open("data/dbn.pkl", 'wb') as f:
                cPickle.dump(weights, f)

        def fineTune(self, trainset, labelset):
            # Train the softmax classifier on the top RBM's output.
            # Only the softmax parameters are learned here; the RBM
            # weights are left unchanged.
            rbm_output = self.calcRBMForward(trainset)
            # Presumably: max iterations, step size, regularization lambda
            MAXT, step, landa = 100, 1, 0.01
            self.softmax = SoftMax(MAXT, step, landa)
            self.softmax.process_train(rbm_output, labelset, self.ntype)

        def predict(self, x):
            rbm_output = self.calcRBMForward(x)
            return self.softmax.predict(rbm_output)

        def validate(self, testset, labelset):
            # Classify each sample and count correct predictions
            testnum = len(testset)
            correctnum = 0
            for i in range(testnum):
                x = testset[i]
                testtype = self.predict(x)
                orgtype = labelset[i]
                if testtype == orgtype:
                    correctnum += 1
            rate = float(correctnum) / testnum
            print "correctnum = %d, sumnum = %d" % (correctnum, testnum)
            print "Accuracy:%.2f" % (rate)
            return rate

    # 3 RBM layers, 10 classes, 784 input pixels, 200 hidden units in layer 0
    dbn = DBN(3, 10, 784, 200)
    with open('mnist.pkl', 'rb') as f:
        training_data, validation_data, test_data = cPickle.load(f)
    # First 5000 training images, one sample per column: shape (784, 5000)
    training_inputs = [np.reshape(x, 784) for x in training_data[0]]
    data = np.array(training_inputs[:5000]).T
    # First 5000 validation images, one sample per row: shape (5000, 784)
    validation_inputs = [np.reshape(x, 784) for x in validation_data[0]]
    vdata = np.array(validation_inputs[:5000])
    # Run the script twice: the first run pretrains the RBMs and trains the
    # softmax layer; the second run loads the saved parameters and validates.
    if not os.path.exists('data/softmax.pkl'):
        dbn.pretrainRBM(data)
        dbn.fineTune(data.T, training_data[1][:5000])
    else:
        dbn.load_param("data/dbn.pkl", "data/softmax.pkl")
        dbn.validate(vdata, validation_data[1][:5000])

    # Output of the second run:
    #RBM layer0 shape:(784L, 200L)
    #RBM layer1 shape:(200L, 100L)
    #RBM layer2 shape:(100L, 50L)
    #softmax parameter: (10L, 50L)
    #correctnum = 4357, sumnum = 5000
    #Accuracy:0.87
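The SoftMax class comes from the author's next post and is not shown here. As a rough stand-in, the sketch below trains a softmax (multinomial logistic regression) layer on fixed features, which is what fineTune does with the top-RBM outputs. It assumes MAXT, step and landa mean maximum iterations, step size and an L2 penalty weight; the functions train_softmax, predict and accuracy are hypothetical names, and the usage at the end runs on synthetic 2-D data rather than MNIST. The accuracy helper is a vectorized version of the counting loop in validate.

```python
import numpy as np


def softmax_rows(z):
    # Subtract each row's max before exponentiating for numerical stability
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def train_softmax(X, y, n_classes, max_iter=100, step=1.0, lam=0.01):
    """Full-batch gradient descent on L2-regularized cross-entropy.
    X: (N, d) features (e.g. top-RBM outputs), y: (N,) integer labels.
    Returns theta with shape (n_classes, d), like the (10, 50) matrix above."""
    n, d = X.shape
    theta = np.zeros((n_classes, d))
    onehot = np.eye(n_classes)[y]  # (N, n_classes)
    for _ in range(max_iter):
        probs = softmax_rows(X @ theta.T)                # (N, n_classes)
        grad = (probs - onehot).T @ X / n + lam * theta  # (n_classes, d)
        theta -= step * grad
    return theta


def predict(theta, X):
    # Most probable class for each row of X
    return np.argmax(X @ theta.T, axis=1)


def accuracy(theta, X, y):
    # Fraction of rows whose predicted class matches the label
    return float(np.mean(predict(theta, X) == y))


# Usage: two well-separated Gaussian clusters as stand-in "features"
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(-2.0, 0.5, (50, 2)), rng.normal(2.0, 0.5, (50, 2))])
y = np.array([0] * 50 + [1] * 50)
theta = train_softmax(X, y, n_classes=2)
print(theta.shape)  # (2, 2)
print(accuracy(theta, X, y))
```

Because the features are fixed, this stage is a convex problem; a full fine-tuning pass that also back-propagates into the RBM weights would be a different (and more expensive) design.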
  • Original article: https://www.cnblogs.com/qw12/p/5906778.html