zoukankan      html  css  js  c++  java
  • TensorFlow CIFAR-10(二分类)

    数据下载(下文代码使用 Python 版本,解压后得到 cifar-10-batches-py 目录):

    (共有三个版本:Python 版、MATLAB 版,以及适用于 C 语言的二进制版)

    http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

    http://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz

    http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz

    import os
    import _pickle as cPickle #python3
    import numpy as np

    import tensorflow as tf
    import matplotlib.pyplot as plt
    from matplotlib.pyplot import imshow

    # Dataset directory produced by extracting cifar-10-python.tar.gz; the
    # script reads data_batch_1 .. data_batch_5 and test_batch from it.
    # (Was "./cifa-10-batches-py" - a typo that made every load fail.)
    CIFAR_DIR = "./cifar-10-batches-py"

    def load_data(filename):
      """Unpickle one CIFAR-10 python batch file.

      Returns:
        A ``(data, labels)`` pair taken from the ``b'data'`` and ``b'labels'``
        entries of the unpickled dict.

      Note: unpickling can execute arbitrary code - only open trusted files.
      """
      # Dictionary keys come back as bytes, hence encoding='bytes' and the
      # b'...' lookups below.
      with open(filename, 'rb') as batch_file:
        batch = cPickle.load(batch_file, encoding='bytes')
      return batch[b'data'], batch[b'labels']

    class CifarData:
      """In-memory CIFAR-10 subset served as sequential mini-batches.

      Only examples whose label is in ``classes`` are kept (by default labels
      0 and 1, giving the binary dataset this script trains on). Pixel values
      are rescaled from [0, 255] to [-1, 1].
      """

      def __init__(self, filenames, need_shuffle, classes=(0, 1)):
        """Load every batch file in ``filenames`` and keep the requested classes.

        Args:
          filenames: pickled CIFAR-10 python batch files, each read with
            ``load_data``.
          need_shuffle: if true, shuffle now and reshuffle whenever
            ``next_batch`` runs past the end; if false, ``next_batch`` raises
            once the data is exhausted.
          classes: labels to keep. Defaults to (0, 1), the original behaviour.
        """
        data_parts = []
        label_parts = []
        for filename in filenames:
          data, labels = load_data(filename)
          # Vectorised filter instead of a per-example Python loop.
          kept_data, kept_labels = self._filter_classes(data, labels, classes)
          data_parts.append(kept_data)
          label_parts.append(kept_labels)

        # Scale pixels from [0, 255] to [-1, 1].
        self._data = np.concatenate(data_parts) / 127.5 - 1
        self._labels = np.concatenate(label_parts)

        print("============================")
        print(self._data.shape)
        print(self._labels.shape)

        self._num_examples = self._data.shape[0]
        self._need_shuffle = need_shuffle
        self._indicator = 0  # index of the first example of the next batch
        if self._need_shuffle:
          self._shuffle_data()

      @staticmethod
      def _filter_classes(data, labels, classes):
        """Return the rows of ``data`` and entries of ``labels`` whose label is in ``classes``."""
        labels = np.asarray(labels)
        mask = np.isin(labels, list(classes))
        return np.asarray(data)[mask], labels[mask]

      def _shuffle_data(self):
        """Apply one random permutation to data and labels together."""
        # e.g. [0,1,2,3,4,5] -> [5,3,2,4,0,1]
        p = np.random.permutation(self._num_examples)
        self._data = self._data[p]
        self._labels = self._labels[p]

      def next_batch(self, batch_size):
        """Return the next ``batch_size`` examples as ``(data, labels)``.

        When fewer than ``batch_size`` examples remain, a shuffling dataset
        reshuffles and restarts from the beginning (the leftover tail of the
        pass is skipped); a non-shuffling dataset raises instead.

        Raises:
          ValueError: if ``batch_size`` exceeds the total number of examples.
          Exception: if the data is exhausted and shuffling is disabled.
        """
        # Validate first so an impossible request is reported as such rather
        # than as "no more examples".
        if batch_size > self._num_examples:
          raise ValueError("batch size is larger than all examples")

        end_indicator = self._indicator + batch_size
        if end_indicator > self._num_examples:
          if not self._need_shuffle:
            raise Exception("have no more examples...")
          self._shuffle_data()
          self._indicator = 0
          end_indicator = batch_size

        start = self._indicator
        self._indicator = end_indicator
        return self._data[start:end_indicator], self._labels[start:end_indicator]

    # ---- Model: logistic regression over flattened images ----
    # Input batch: one row of 3072 features per image (values in [-1, 1]
    # after CifarData's rescaling).
    x = tf.placeholder(tf.float32,[None,3072])
    # Labels, shape [None]; int64 class ids (0 or 1 after CifarData's filter).
    y = tf.placeholder(tf.int64,[None])

    # Weight column, shape (3072, 1), drawn from N(0, 1).
    # NOTE(review): with unit-variance init over 3072 inputs the logits can be
    # large, saturating the sigmoid; together with the MSE loss below this may
    # slow learning. Sigmoid cross-entropy and/or a smaller init are common
    # alternatives - untested here.
    w = tf.get_variable('w',[x.get_shape()[-1],1],
    initializer = tf.random_normal_initializer(0,1))

    # Bias, shape (1,), initialised to 0.
    b = tf.get_variable('b',[1],
    initializer = tf.constant_initializer(0.0))

    # Logits: [None,3072] x [3072,1] + bias -> [None,1].
    y_ = tf.matmul(x,w)+b

    # Predicted probability of label 1, shape [None,1].
    p_y_1 = tf.nn.sigmoid(y_)

    # Labels reshaped to [None,1] so they line up with p_y_1.
    y_reshaped = tf.reshape(y,(-1,1))
    y_reshaped_float = tf.cast(y_reshaped,tf.float32)

    # Mean squared error between the 0/1 label and the predicted probability.
    loss = tf.reduce_mean(tf.square(y_reshaped_float-p_y_1))

    # Boolean [None,1]: predict label 1 when the probability exceeds 0.5.
    predict = p_y_1>0.5
    # Elementwise match between predicted and true labels, e.g. [True, False, ...].
    correct_prediction = tf.equal(tf.cast(predict,tf.int64),y_reshaped)

    # Fraction of correct predictions in the fed batch.
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float64))

    # Adam, learning rate 1e-3, minimising the MSE loss above.
    with tf.name_scope('train_op'):
      train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

    # ---- Data ----
    # Five training batch files (data_batch_1 .. data_batch_5) and one test file.
    train_filenames = [
      os.path.join(CIFAR_DIR, 'data_batch_%d' % idx)
      for idx in range(1, 6)
    ]
    test_filenames = [os.path.join(CIFAR_DIR, 'test_batch')]

    # Training data is shuffled; test data is served in file order.
    train_data = CifarData(train_filenames, True)
    test_data = CifarData(test_filenames, False)

    # ---- Hyper-parameters ----
    batch_size = 20
    train_steps = 100000  # total optimiser steps
    test_steps = 100      # evaluation batches per test pass

    # Initialiser for every global variable created so far (w, b and any
    # variables the Adam optimizer added).
    init = tf.global_variables_initializer()

    # Fetch the whole evaluation set once. The test CifarData is not shuffled,
    # so next_batch walks it in file order; reading test_steps * batch_size
    # examples up front lets every evaluation reuse them instead of rebuilding
    # CifarData (re-reading the test file from disk) just to rewind its cursor.
    # Raises "have no more examples..." if the test split is smaller than that.
    eval_size = test_steps * batch_size
    eval_data, eval_labels = test_data.next_batch(eval_size)

    with tf.Session() as sess:
      sess.run(init)
      for i in range(train_steps):
        batch_data, batch_labels = train_data.next_batch(batch_size)
        # One optimisation step; loss/accuracy are measured on this batch.
        loss_val, acc_val, _ = sess.run(
          [loss, accuracy, train_op],
          feed_dict={x: batch_data, y: batch_labels})

        if (i + 1) % 500 == 0:
          print('Train step:%d,loss:%4.5f,acc:%4.5f'
            % (i + 1, loss_val, acc_val))

        if (i + 1) % 5000 == 0:
          # Evaluate in batch_size chunks. Every chunk has the same size, so
          # the mean of per-chunk accuracies equals the overall accuracy.
          all_test_acc_val = []
          for start in range(0, eval_size, batch_size):
            stop = start + batch_size
            test_acc_val = sess.run(
              accuracy,
              feed_dict={x: eval_data[start:stop], y: eval_labels[start:stop]})
            all_test_acc_val.append(test_acc_val)
          test_acc = np.mean(all_test_acc_val)
          print('Test Step:%d, acc:%4.5f' % (i + 1, test_acc))

  • 相关阅读:
    手误【删库】 == 跑路,不存在的 Linux回收站
    大规模集群全网数据备份解决方案
    宝塔Nginx配置防盗链
    Markdown语法
    QFtp编程模型(二)
    Ubuntu驱动程序开发6-Linux内核启动与程序烧写
    Ubuntu下TFTP、NFS和SSH服务搭建
    ubuntu环境变量的三种设置方式
    QByteArray详解
    mysql的索引下推理解和实践
  • 原文地址:https://www.cnblogs.com/herd/p/10231790.html
Copyright © 2011-2022 走看看