zoukankan      html  css  js  c++  java
  • tensorflow训练自己的数据集实现CNN图像分类1

    利用卷积神经网络训练图像数据分为以下几个步骤

    1. 读取图片文件
    2. 产生用于训练的批次
    3. 定义训练的模型(包括初始化参数,卷积、池化层等参数、网络
    4. 训练

    读取图片文件

     1 def get_files(filename):
     2     class_train = []
     3     label_train = []
     4     for train_class in os.listdir(filename):
     5         for pic in os.listdir(filename+train_class):
     6             class_train.append(filename+train_class+'/'+pic)
     7             label_train.append(train_class)
     8     temp = np.array([class_train,label_train])
     9     temp = temp.transpose()
    10     #shuffle the samples
    11     np.random.shuffle(temp)
    12     #after transpose, images is in dimension 0 and label in dimension 1
    13     image_list = list(temp[:,0])
    14     label_list = list(temp[:,1])
    15     label_list = [int(i) for i in label_list]
    16     #print(label_list)
    17     return image_list,label_list

      这里文件名作为标签,即类别(其数据类型要确定,后面要转为tensor类型数据)。

      然后将image和label转为list格式数据,因为后边用到的的一些tensorflow函数接收的是list格式数据。

    产生用于训练的批次

     1 def get_batches(image,label,resize_w,resize_h,batch_size,capacity):
     2     #convert the list of images and labels to tensor
     3     image = tf.cast(image,tf.string)
     4     label = tf.cast(label,tf.int64)
     5     queue = tf.train.slice_input_producer([image,label])
     6     label = queue[1]
     7     image_c = tf.read_file(queue[0])
     8     image = tf.image.decode_jpeg(image_c,channels = 3)
     9     #resize
    10     image = tf.image.resize_image_with_crop_or_pad(image,resize_w,resize_h)
    11     #(x - mean) / adjusted_stddev
    12     image = tf.image.per_image_standardization(image)
    13     
    14     image_batch,label_batch = tf.train.batch([image,label],
    15                                              batch_size = batch_size,
    16                                              num_threads = 64,
    17                                              capacity = capacity)
    18     images_batch = tf.cast(image_batch,tf.float32)
    19     labels_batch = tf.reshape(label_batch,[batch_size])
    20     return images_batch,labels_batch

      首先使用tf.cast转化为tensorflow数据格式,使用tf.train.slice_input_producer实现一个输入的队列。

      label不需要处理,image存储的是路径,需要读取为图片,接下来的几步就是读取路径转为图片,用于训练。

      CNN对图像大小是敏感的,第10行图片resize处理为大小一致,12行将其标准化,即减去所有图片的均值,方便训练。

      接下来使用tf.train.batch函数产生训练的批次。

      最后将产生的批次做数据类型的转换和shape的处理即可产生用于训练的批次。

    3 定义训练的模型

    (1)训练参数的定义及初始化

     1 def init_weights(shape):
     2     return tf.Variable(tf.random_normal(shape,stddev = 0.01))
     3 #init weights
     4 weights = {
     5     "w1":init_weights([3,3,3,16]),
     6     "w2":init_weights([3,3,16,128]),
     7     "w3":init_weights([3,3,128,256]),
     8     "w4":init_weights([4096,4096]),
     9     "wo":init_weights([4096,2])
    10     }
    11 
    12 #init biases
    13 biases = {
    14     "b1":init_weights([16]),
    15     "b2":init_weights([128]),
    16     "b3":init_weights([256]),
    17     "b4":init_weights([4096]),
    18     "bo":init_weights([2])
    19     }

      CNN的每层是y=wx+b的决策模型,卷积层产生特征向量,根据这些特征向量带入x进行计算,因此,需要定义卷积层的初始化参数,包括权重和偏置。其中第8行的参数形状后边再解释。

     (2)定义不同层的操作

     1 def conv2d(x,w,b):
     2     x = tf.nn.conv2d(x,w,strides = [1,1,1,1],padding = "SAME")
     3     x = tf.nn.bias_add(x,b)
     4     return tf.nn.relu(x)
     5 
     6 def pooling(x):
     7     return tf.nn.max_pool(x,ksize = [1,2,2,1],strides = [1,2,2,1],padding = "SAME")
     8 
     9 def norm(x,lsize = 4):
    10     return tf.nn.lrn(x,depth_radius = lsize,bias = 1,alpha = 0.001/9.0,beta = 0.75)

      这里只定义了三种层,即卷积层、池化层和正则化层

     (3)定义训练模型

     1 def mmodel(images):
     2     l1 = conv2d(images,weights["w1"],biases["b1"])
     3     l2 = pooling(l1)
     4     l2 = norm(l2)
     5     l3 = conv2d(l2,weights["w2"],biases["b2"])
     6     l4 = pooling(l3)
     7     l4 = norm(l4)
     8     l5 = conv2d(l4,weights["w3"],biases["b3"])
     9     #same as the batch size
    10     l6 = pooling(l5)
    11     l6 = tf.reshape(l6,[-1,weights["w4"].get_shape().as_list()[0]])
    12     l7 = tf.nn.relu(tf.matmul(l6,weights["w4"])+biases["b4"])
    13     soft_max = tf.add(tf.matmul(l7,weights["wo"]),biases["bo"])
    14     return soft_max

      模型比较简单,使用三层卷积,第11行使用全连接,需要对特征向量进行reshape,其中l6的形状为[-1,w4的第1维的参数],因此,将其按照“w4”reshape的时候,要使得-1位置的大小为batch_size,这样,最终再乘以“wo”时,最终的输出大小为[batch_size,class_num]

    (4)定义评估量

    1 def loss(logits,label_batches):
    2     cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,labels=label_batches)
    3     cost = tf.reduce_mean(cross_entropy)
    4     return cost

      首先定义损失函数,这是用于训练最小化损失的必需量

    1 def get_accuracy(logits,labels):
    2     acc = tf.nn.in_top_k(logits,labels,1)
    3     acc = tf.cast(acc,tf.float32)
    4     acc = tf.reduce_mean(acc)
    5     return acc

      评价分类准确率的量,训练时,需要loss值减小,准确率增加,这样的训练才是收敛的。

    (5)定义训练方式

    1 def training(loss,lr):
    2     train_op = tf.train.RMSPropOptimizer(lr,0.9).minimize(loss)
    3     return train_op

      有很多种训练方式,可以自行去官网查看,但是不同的训练方式可能对应前面的参数定义不一样,需要另行处理,否则可能报错。

     4 训练

     1 def run_training():
     2     data_dir = 'C:/Users/wk/Desktop/bky/dataSet/'
     3     image,label = inputData.get_files(data_dir)
     4     image_batches,label_batches = inputData.get_batches(image,label,32,32,16,20)
     5     p = model.mmodel(image_batches)
     6     cost = model.loss(p,label_batches)
     7     train_op = model.training(cost,0.001)
     8     acc = model.get_accuracy(p,label_batches)
     9     
    10     sess = tf.Session()
    11     init = tf.global_variables_initializer()
    12     sess.run(init)
    13     
    14     coord = tf.train.Coordinator()
    15     threads = tf.train.start_queue_runners(sess = sess,coord = coord)
    16     
    17     try:
    18        for step in np.arange(1000):
    19            print(step)
    20            if coord.should_stop():
    21                break
    22            _,train_acc,train_loss = sess.run([train_op,acc,cost])
    23            print("loss:{} accuracy:{}".format(train_loss,train_acc))
    24     except tf.errors.OutOfRangeError:
    25         print("Done!!!")
    26     finally:
    27         coord.request_stop()
    28     coord.join(threads)
    29     sess.close()

      

  • 相关阅读:
    处理ORACLE死锁
    正则表达式 浮点数 整型
    Oracle数据类型number(m,n)
    chm文件打开无法正常显示内容
    安装PHP程序提示“include_path='.;c:php5pear'”错误的解决方法
    mysql数据导入数据报错(数据丢失)
    微信小程序如何与数据库交互?
    Eclipse 常用快捷键
    深入理解BodyTagSupport,包括SKIP_PAGE, EVAL_PAGE等
    JSP自定义标签Taglib实现过程重点总结
  • 原文地址:https://www.cnblogs.com/wktwj/p/7227544.html
Copyright © 2011-2022 走看看