zoukankan      html  css  js  c++  java
  • Softmax回归

      使用参考上一篇随笔。详细介绍可参考http://ufldl.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92    

           

       

             

     1 # coding:utf8
import os

try:
    import cPickle  # Python 2: C implementation of pickle
except ImportError:
    import pickle as cPickle  # Python 3: the accelerated module is plain "pickle"

import numpy as np
     5 
     6 
     7 class SoftMax:
     8     def __init__(self, MAXT=100, step=0.1, landa=0.01):
     9         self.MAXT = MAXT
    10         self.step = step
    11         self.landa = landa
    12 
    13     def load_theta(self, datapath):
    14         self.theta = cPickle.load(open(datapath, 'rb'))
    15 
    16     def process_train(self, x, y, typenum):
    17         costval = np.zeros(self.MAXT)
    18         print "Trainnum = %d, x.shape[1] = %d" % (x.shape[0], x.shape[1])
    19         self.theta = 0.001 * np.mat(np.random.randn(typenum, x.shape[1]))
    20         lastcostJ = 1000
    21         for m in range(self.MAXT):
    22             costs = np.zeros((typenum, x.shape[0]))
    23             grads = np.zeros((typenum, x.shape[1]))
    24             hval = self.h(x)
    25             for j in range(typenum):
    26                 jvalues = np.zeros((x.shape[0], x.shape[1]))
    27                 for i in range(x.shape[0]):
    28                     ptype = hval[i, j]
    29                     delta = (j == y[i])-ptype
    30                     costs[j, i] = (j == y[i])*np.log(ptype)
    31                     jvalues[i] = x[i] * delta
    32                 grads[j] = -np.mean(jvalues, axis=0) + self.landa * self.theta[j]  # 权重衰减项
    33             self.theta = self.theta - self.step * grads
    34             costJ = -np.sum(costs) / x.shape[0] + (self.landa / 2) * np.sum(np.square(self.theta))
    35             costval[m] = costJ
    36             if (costJ > lastcostJ):
    37                 print "costJ is increasing !!!"
    38                 break
    39             print "Loop(%d) cost = %.3f diff=%.4f" % (m, costJ, costJ - lastcostJ)
    40             lastcostJ = costJ
    41         if not os.path.exists('data'):
    42             os.makedirs('data')
    43         f = open("data/softmax.pkl", 'wb')
    44         cPickle.dump(self.theta, f)
    45         f.close()
    46 
    47     def h(self, x):
    48         m = np.exp(np.dot(x, self.theta.T))
    49         sump = np.sum(m, axis=1)
    50         return m / sump
    51 
    52     def predict(self, x):
    53         pv = self.h(x)
    54         return np.argmax(pv)
    55 
    56 
    57     def validate(self, testset, labelset):
    58         testnum = len(testset)
    59         correctnum = 0
    60         for i in range(testnum):
    61             x = testset[i]
    62             testtype = self.predict(x)
    63             orgtype = labelset[i]
    64             if testtype == orgtype:
    65                 correctnum += 1
    66         rate = float(correctnum) / testnum
    67         print "correctnum = %d, sumnum = %d" % (correctnum, testnum)
    68         print "Accuracy:%.2f" % (rate)
    69         return rate
    70 
    71 
    72 if __name__ == '__main__':
    73         f = open('mnist.pkl', 'rb')
    74         training_data, validation_data, test_data = cPickle.load(f)
    75         training_inputs = [np.reshape(x, 784) for x in training_data[0]]
    76         data = np.array(training_inputs[:5000])
    77         training_inputs = [np.reshape(x, 784) for x in validation_data[0]]
    78         vdata = np.array(training_inputs[:5000])
    79         f.close()
    80         softmax = SoftMax()
    81         softmax.process_train(data, training_data[1][:5000], 10)
    82         softmax.validate(vdata, validation_data[1][:5000])
    83         # Accuracy:0.85
  • 相关阅读:
    谈谈一些有趣的CSS题目(十四)-- 纯 CSS 方式实现 CSS 动画的暂停与播放!
    Oracle 12c CDB PDB
    sqlplus 调试存储过程
    Oracle 存储过程A
    %notfound的理解——oracle存储过程 .
    ORA-04091: table xxx is mutating, trigger/function may not see it
    JQuery Div scrollTop ScrollHeight
    JQuery JSON Servlet
    div 移动
    HTML 三角符号
  • 原文地址:https://www.cnblogs.com/qw12/p/5910772.html
Copyright © 2011-2022 走看看