zoukankan      html  css  js  c++  java
  • Softmax回归(使用TensorFlow)

     1 # coding:utf8
     2 import numpy as np
     3 import cPickle
     4 import os
     5 import tensorflow as tf
     6 
     7 class SoftMax:
     8     def __init__(self,MAXT=30,step=0.0025):
     9         self.MAXT = MAXT
    10         self.step = step
    11         
    12     def load_theta(self,datapath="data/softmax.pkl"):
    13         self.theta = cPickle.load(open(datapath,'rb'))
    14 
    15     def process_train(self,data,label,typenum=10,batch_size=500):
    16         batches =  data.shape[0] / batch_size
    17         valuenum=data.shape[1]
    18         if len(label.shape)==1:
    19             label=self.reshape_data(label,typenum)
    20         x = tf.placeholder("float", [None,valuenum])
    21         theta = tf.Variable(tf.zeros([valuenum,typenum]))
    22         y = tf.nn.softmax(tf.matmul(x,theta))
    23         y_ = tf.placeholder("float", [None, typenum])
    24         cross_entropy = -tf.reduce_sum(y_*tf.log(y))  #交叉熵
    25         train_step = tf.train.GradientDescentOptimizer(self.step).minimize(cross_entropy)
    26         init = tf.initialize_all_variables()
    27         sess = tf.Session()
    28         sess.run(init)
    29         for epoch in range(self.MAXT):
    30             cost_=[]
    31             for index in xrange(batches):
    32                 c_,_=sess.run([cross_entropy,train_step], feed_dict={ x: data[index * batch_size: (index + 1) * batch_size],
    33                 y_: label[index * batch_size: (index + 1) * batch_size]})
    34                 cost_.append(c_)
    35             if epoch % 5 == 0:
    36                 print(( 'epoch %i, minibatch %i/%i,averange cost is %f') %
    37                         (epoch,index + 1,batches,np.mean(cost_)))
    38         self.theta=sess.run(theta)
    39         if not os.path.exists('data/softmax.pkl'):
    40             f= open("data/softmax.pkl",'wb')
    41             cPickle.dump(self.theta,f)
    42             f.close()
    43         return self.theta
    44 
    45 
    46     def process_test(self,data,label,typenum=10):
    47         valuenum=data.shape[1]
    48         if len(label.shape)==1:
    49             label=self.reshape_data(label,typenum)
    50         x = tf.placeholder("float", [None,valuenum])
    51         theta = self.theta
    52         y = tf.nn.softmax(tf.matmul(x,theta))
    53         y_ = tf.placeholder("float", [None, typenum])
    54         init = tf.initialize_all_variables()
    55         sess = tf.Session()
    56         sess.run(init)
    57         correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    58         accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    59         print "Accuracy: ",sess.run(accuracy, feed_dict={x: data,y_: label})
    60 
    61     def h(self,x):
    62         m = np.exp(np.dot(x,self.theta))
    63         sump = np.sum(m,axis=1)
    64         return m/sump
    65 
    66     def predict(self,x):
    67         return np.argmax(self.h(x),axis=1)
    68 
    69     def reshape_data(self,label,typenum):
    70         label_=[]
    71         for yl_ in label:
    72             tl_=np.zeros(typenum)
    73             tl_[yl_]=1.0
    74             label_.append(tl_)
    75         return np.mat(label_)
    76 
    77 if __name__ == '__main__':
    78     f = open('mnist.pkl', 'rb')
    79     training_data, validation_data, test_data = cPickle.load(f)
    80     training_inputs = [np.reshape(x, 784) for x in training_data[0]]
    81     data = np.array(training_inputs)
    82     training_inputs = [np.reshape(x, 784) for x in validation_data[0]]
    83     vdata = np.array(training_inputs)
    84     f.close()
    85 
    86     softmax = SoftMax()
    87     softmax.process_train(data,training_data[1])
    88     softmax.process_test(vdata,validation_data[1])  #Accuracy:  0.9269
    89     softmax.process_test(data,training_data[1])  #Accuracy:  0.92718
  • 相关阅读:
    微信推送给服务器的XML消息解析-springmvc 解析xml数据流
    request.getInputStream() 的两种解析方式
    微信的token验证
    springmvc 解析xml数据
    Spring 定时器 No qualifying bean of type [org.springframework.scheduling.TaskScheduler] is defined
    纯CSS实现图片
    Java线程池应用
    JavaScript 插件的书页翻转效果
    c语言中字符串函数的使用
    窗体显示类
  • 原文地址:https://www.cnblogs.com/qw12/p/5962430.html
Copyright © 2011-2022 走看看