zoukankan      html  css  js  c++  java
  • VGGnet——从TFrecords制作到网络训练

    作为一个小白中的小白,多折腾总是有好处的,看了入门书和往上一些教程,很多TF的教程都是从MNIST数据集入手教小白入TF的大门,都是直接import MNIST,然后直接构建网络,定义loss和optimizer,设置超参数,之后就直接sess.run()了,虽然操作流程看上去很简单,但如果直接给自己一堆图片,如何让tensorflow读取,如何喂入网络进行训练,这些都不清楚,所以作为小白,先从最简单的CNN——VGGnet入手吧,在网上随便下载了个数据集——GTSRB(因为这个数据集最小,下载快。。= =),下载下来的数据的前处理已经在另一篇博文数据图片处理介绍,这篇主要是TFrecords文件的制作和读取,我不是CS专业,研究方向也跟这个毫不相关,(刚入学时和导师约定好的计算机视觉方向现在被否了,一度让我想换导师,说来话长,此处省略一万字),一边要忙导师那边的东西,一边搞这个,可以说是很酸爽了 = =。。。这个程序折腾了近2个星期,最后可算是制服所有八阿哥,成功运行了,进入了所谓的“调参”环节,目前还很不理想,也许下面的程序还存在错误,但对于我这个小白来讲这次折腾已经学到很多东西了。

    下面进入正题。。。

    TFrecords文件是tensorflow读取数据的方式之一,主要用于数据较大的情况,TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features),

    可以将自己的数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。

    从TFRecords文件中读取数据, 可以使用tf.TFRecordReadertf.parse_single_example解析器。这个操作可以将Example协议内存块(protocol buffer)解析为张量。

    上面的内容来自:https://www.cnblogs.com/upright/p/6136265.html

    下面直接贴代码吧,有些部分并非原创,很多说明都写在代码中了(好吧,我承认我懒。。= =,这篇以后会更新的)


    VGGnet.py:

      1 # -*- coding: utf-8 -*-
      2 import tensorflow as tf
      3 import time
      4 import convert_TFrecords
      5 
      6 # 网络超参数
      7 learning_rate = 0.005
      8 batch_size = 300
      9 epoch = 20000
     10 display_step = 10
     11 
     12 # 网络参数
     13 Dropout = 0.75  # 失活的概率=1-Dropout
     14 n_inputs = 128 * 128 * 3  # 输入维度(img_size)
     15 n_classes = 43
     16 
     17 
     18 weights = {'w1': tf.Variable(tf.random_normal([3, 3, 3, 16])),
     19            'w2': tf.Variable(tf.random_normal([3, 3, 16, 16])),
     20            'w3': tf.Variable(tf.random_normal([3, 3, 16, 32])),
     21            'w4': tf.Variable(tf.random_normal([3, 3, 32, 32])),
     22            'w5': tf.Variable(tf.random_normal([3, 3, 32, 64])),
     23            'w6': tf.Variable(tf.random_normal([3, 3, 64, 64])),
     24            'w7': tf.Variable(tf.random_normal([3, 3, 64, 128])),
     25            'w8': tf.Variable(tf.random_normal([3, 3, 128, 128])),
     26            'w9': tf.Variable(tf.random_normal([3, 3, 128, 128])),
     27            'w10': tf.Variable(tf.random_normal([3, 3, 128, 128])),
     28            'wd1': tf.Variable(tf.random_normal([8*8*128, 4096])),
     29            'wd2': tf.Variable(tf.random_normal([1*1*4096, 4096])),
     30            'out': tf.Variable(tf.random_normal([4096, 43]))}  # 共43个类别
     31 
     32 biases = {'b1': tf.Variable(tf.random_normal([16])),
     33           'b2': tf.Variable(tf.random_normal([16])),
     34           'b3': tf.Variable(tf.random_normal([32])),
     35           'b4': tf.Variable(tf.random_normal([32])),
     36           'b5': tf.Variable(tf.random_normal([64])),
     37           'b6': tf.Variable(tf.random_normal([64])),
     38           'b7': tf.Variable(tf.random_normal([128])),
     39           'b8': tf.Variable(tf.random_normal([128])),
     40           'b9': tf.Variable(tf.random_normal([128])),
     41           'b10': tf.Variable(tf.random_normal([128])),
     42           'bd1': tf.Variable(tf.random_normal([4096])),
     43           'bd2': tf.Variable(tf.random_normal([4096])),
     44           'out': tf.Variable(tf.random_normal([43]))}
     45 
     46 
     47 def conv(name, input, W, b, strides=1, padding='SAME'):
     48     x = tf.nn.conv2d(input, W, strides=[1, strides, strides, 1], padding=padding)
     49     x = tf.nn.bias_add(x, b)
     50     return tf.nn.relu(x, name=name)
     51 
     52 
     53 # 输入应该是一个4维的张量,最后一维为batch_size,但这里构造的网络只按batch_size=1的情况来构造,即只考虑
     54 # 一个样本的情况,这是没有影响的,运行图的时候再指定batch_size
     55 def VGGnet(input, weights, biases, keep_prob):
     56     x = tf.reshape(input, shape=[-1, 128, 128, 3])   # -1处的值由batch_size决定
     57     conv1 = conv('conv1', x, weights['w1'], biases['b1'])
     58 
     59     conv2 = conv('conv2', conv1, weights['w2'], biases['b1'])
     60 
     61     pool1 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool1')
     62 
     63     conv3 = conv('conv3', pool1, weights['w3'], biases['b3'])
     64 
     65     conv4 = conv('conv4', conv3, weights['w4'], biases['b4'])
     66 
     67     pool2 = tf.nn.max_pool(conv4, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool2')
     68 
     69     conv5 = conv('conv5', pool2, weights['w5'], biases['b5'])
     70 
     71     conv6 = conv('conv6', conv5, weights['w6'], biases['b6'])
     72 
     73     conv7 = conv('conv7', conv6, weights['w7'], biases['b7'])
     74 
     75     pool3 = tf.nn.max_pool(conv7, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool3')
     76 
     77     conv8 = conv('conv8', pool3, weights['w8'], biases['b8'])
     78 
     79     conv9 = conv('conv9', conv8, weights['w9'], biases['b9'])
     80 
     81     conv10 = conv('conv10', conv9, weights['w10'], biases['b10'])
     82 
     83     pool4 = tf.nn.max_pool(conv10, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool4')
     84 
     85     fc1 = tf.reshape(pool4, [-1, weights['wd1'].get_shape().as_list()[0]])
     86     fc1 = tf.add(tf.matmul(fc1, weights['wd1']), biases['bd1'])
     87 
     88     re1 = tf.nn.relu(fc1, 're1')
     89 
     90     drop1 = tf.nn.dropout(re1, keep_prob)
     91 
     92     fc2 = tf.reshape(drop1, [-1, weights['wd2'].get_shape().as_list()[0]])
     93     fc2 = tf.add(tf.matmul(fc2, weights['wd2']), biases['bd2'])
     94 
     95     re2 = tf.nn.relu(fc2, 're2')
     96 
     97     drop2 = tf.nn.dropout(re2, keep_prob)
     98 
     99     fc3 = tf.reshape(drop2, [-1, weights['out'].get_shape().as_list()[0]])
    100     fc3 = tf.add(tf.matmul(fc3, weights['out']), biases['out'])
    101 
    102     # print(fc3) 检查点
    103 
    104     # tf.nn.softmax_cross_entropy_with_logits函数已经进行了softmax处理!不必再加一层softmax(发现这个错误后,训练精度终于变得正常)
    105     # sm = tf.nn.softmax(fc3)
    106 
    107     return fc3
    108 
    109 
    110 # 注意下面的shape要和传入的tensor一致!使用mnist数据集时x的shape为[none, 28*28*1],是因为传入的数据是展开成行的
    111 x = tf.placeholder(tf.float32, [None, 128, 128, 3])
    112 y = tf.placeholder(tf.float32, [None, n_classes])
    113 dropout = tf.placeholder(tf.float32)
    114 
    115 pred = VGGnet(x, weights, biases, dropout)
    116 
    117 # 定义损失函数和优化器
    118 # 错误:Only call `softmax_cross_entropy_with_logits` with named arguments (labels=..., logits=...,),解决方法:参数要以关键字参数的形式传入
    119 # tf.nn.softmax_cross_entropy_with_logits先是对最后一层输出做一个softmax,然后求softmax向量里每个元素的这个值:y_i * log(yi)(y_i为实际值,yi为预测值),
    120 # tf.reduce_mean对每个元素上面的乘积求和再平均
    121 # 参考:https://blog.csdn.net/mao_xiao_feng/article/details/53382790
    122 loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
    123 optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
    124 
    125 # 评估函数
    126 # tf.argmax()返回每个向量最大元素的索引(axis=1),tf.equal()返回两个数是否相等(ture or false)
    127 # https://blog.csdn.net/qq575379110/article/details/70538051/
    128 # https://blog.csdn.net/uestc_c2_403/article/details/72232924
    129 correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    130 accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    131 
    132 # init = tf.initialize_all_variables()
    133 batch_x, batch_y = convert_TFrecords.inputs(True, batch_size, epoch)
    134 
    135 with tf.Session() as sess:
    136     # sess.run(init)
    137     # 先执行初始化工作
    138     # 参考:https://blog.csdn.net/lujiandong1/article/details/53376802
    139     sess.run(tf.global_variables_initializer())
    140     sess.run(tf.local_variables_initializer())
    141     # sess.run(tf.initialize_all_variables())
    142 
    143     # 开启一个协调器
    144     coord = tf.train.Coordinator()
    145     # 使用start_queue_runners 启动队列填充
    146     threads = tf.train.start_queue_runners(sess, coord)
    147 
    148     try:
    149         step = 1
    150         while not coord.should_stop():
    151             # 获取每一个batch中batch_size个样本和标签
    152             # 原来下面这一句放在这个位置(改变这一句的位置后卡了几天的问及终于解决了):
    153             # batch_x, batch_y = convert_TFrecords.inputs(True, batch_size, epoch)
    154             # 结果程序卡住,无法运行,也不报错
    155             # 检查点:print('kaka')
    156 
    157             # print(batch_x)
    158             # print(batch_y)
    159             # print('okok') 检查点
    160             # 没有下面这句会报错:
    161             # The value of a feed cannot be a tf.Tensor object. Acceptable feed
    162             # values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles.
    163             # 原以为是要用tensor.eval()将tensor转为np.array,但batch_x, batch_y = convert_TFrecords.inputs(True, batch_size, epoch)
    164             # 那时是放在sess里面,所以执行到tensor.eval()时一样会卡住不动
    165             b_x, b_y = sess.run([batch_x, batch_y])
    166             # print('haha') 检查点
    167             # 打印出tesor:默认值打印出3个参数  参考:https://blog.csdn.net/qq_34484472/article/details/75049179
    168             # print(b_x, b_y) 检查点
    169             # 这里原先喂入dict的tensor变量名不是b_x,b_y,而是和key名一样(也就是x,y),变量名与占位符名冲突,结果
    170             # 会报错:unhashable type: 'numpy.ndarray' error
    171             # 这个错误也有可能是其他原因引起,见:https://blog.csdn.net/wongleetion/article/details/80885648
    172             start = time.time()
    173             sess.run(optimizer, feed_dict={x: b_x, y: b_y, dropout: Dropout})
    174             if step % display_step == 0:
    175                 # 原来在feed_dict里关键字dropout打错成keep_prob了,结果弹出Cannot interpret feed_dict key
    176                 # as Tensor:Can not convert a float into a Tensor错误
    177                 # 参考https://blog.csdn.net/ice_pill/article/details/78567841
    178                 Loss, acc = sess.run([loss, accuracy], feed_dict={x: b_x, y: b_y, dropout: 1.0})
    179                 print('iter ' + str(step) + ', minibatch loss = ' +
    180                       '{: .6f}'.format(Loss) + ', training accuracy = ' + '{: .5f}'.format(acc))
    181                 # sess.run(tf.Print(b_y, [b_y], summarize=43))
    182                 print(b_y)
    183             print('iter %d, duration: %.2fs' % (step, time.time() - start))
    184             step += 1
    185     except tf.errors.OutOfRangeError:  # 如果读取到文件队列末尾会抛出此异常
    186         print("done! now lets kill all the threads……")
    187     finally:
    188         # 协调器coord发出所有线程终止信号
    189         coord.request_stop()
    190         print('all threads are asked to stop!')
    191     coord.join(threads)  # 把开启的线程加入主线程,等待threads结束
    192     print('all threads are stopped!')
    convert_TFrecords.py(TFrecords文件的制作和读取):
      1 # -*- coding: utf-8 -*-
      2 
      3 import os
      4 import tensorflow as tf
      5 from PIL import Image
      6 
      7 cur_dir = os.getcwd()
      8 
      9 # classes = ['test_file_dir', 'train_file_dir']
     10 train_set = os.path.join(cur_dir, 'train_file_dir')
     11 classes = os.listdir(train_set)
     12 
     13 
     14 # 制作二进制数据
     15 def create_record():
     16     print('processing...')
     17     writer = tf.python_io.TFRecordWriter('train.tfrecords')
     18     num_labels = len([name for name in classes])
     19     print('num of classes: %d' % num_labels)
     20     label = [0] * num_labels
     21     for index, name in enumerate(classes):
     22         class_path = os.path.join(train_set, name)
     23         label[index] = 1
     24         for img_name in os.listdir(class_path):
     25             img_path = os.path.join(class_path, img_name)
     26             img = Image.open(img_path)
     27             # img = img.resize((64, 64))
     28             img_raw = img.tobytes()  # 将图片转化为原生bytes
     29             # print(img_raw)
     30             # print(index,img_raw)
     31             # tfrecord数据文件是一种将图像数据和标签统一存储的二进制文件,能更好的利用内存,在tensorflow中快速的复制,移动,读取,存储等。
     32             # tfrecord文件包含了tf.train.Example 协议缓冲区(protocol buffer,协议缓冲区包含了特征 Features)。你可以写一段代码获取你的数据,
     33             # 将数据填入到Example协议缓冲区(protocol buffer),将协议缓冲区序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter class写
     34             # 入到TFRecords文件
     35             example = tf.train.Example(
     36                 # feature字典中每个key的值都是一个list,这些list是3种数据类型中的一种:FloatList, 或者ByteList,或者Int64List
     37                 # 参考https://blog.csdn.net/u012759136/article/details/52232266
     38                 # 参考https://blog.csdn.net/shenxiaolu1984/article/details/52857437
     39                features=tf.train.Features(feature={
     40                     # 设置图片在TFrecord文件中的标签(同一文件夹下标签一致),注意存储的是一个大小为num_label的list,而不是一个值!!
     41                     'label': tf.train.Feature(int64_list=tf.train.Int64List(value=label)), # label本来就是一个list,不用加中括号
     42                     # 设置图片在TFrecord文件中的值
     43                     'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
     44                }))
     45             writer.write(example.SerializeToString())
     46         label = [0] * num_labels
     47     writer.close()
     48     print('TFrecords file created successfully!!')
     49 
     50 
     51 # 读取二进制数据
     52 def read_and_decode(filename, num_epochs):
     53     # 根据文件名,顺序生成一个队列(如果shuffle=ture)
     54     filename_queue = tf.train.string_input_producer([filename], shuffle=True, num_epochs=num_epochs)
     55     print('qunide')
     56     reader = tf.TFRecordReader()
     57     _, serialized_example = reader.read(filename_queue)   # 返回文件名和文件
     58     features = tf.parse_single_example(serialized_example,
     59                                        features={
     60                                            # 这个函数不是很了解,原来在'label'里的shape为空([]),结果弹出错误:Key: label, Index: 0.  Number
     61                                            # of int64 values != expected.  Values size: 43 but output shape: []
     62                                            # 注意数据类型要和TFrecords文件中一致!!
     63                                            'label': tf.FixedLenFeature([43], tf.int64),
     64                                            'img_raw': tf.FixedLenFeature([], tf.string),     ##########
     65                                        })
     66 
     67     img = features['img_raw']
     68     # decode_raw()函数只能用于解码byteslist格式的数据
     69     img = tf.decode_raw(img, tf.uint8)
     70     img = tf.reshape(img, [128, 128, 3])
     71     img = tf.cast(img, tf.float32) * (1. / 255) - 0.5     # 规范化到±0.5之间
     72     label = features['label']
     73     # label = tf.reshape(label, [43])   ????不用这样做,原本存储的时候shape就是[43]
     74     label = tf.cast(label, tf.float32)    # 因为网络输出的pred值是float32类型的!!(?)
     75     print('label', label)
     76     print('image', img)
     77 
     78     return img, label
     79 
     80 
     81 def inputs(train, batch_size, num_epochs):
     82     print('qunide2')
     83     if not num_epochs:
     84         num_epochs = None
     85     filename = os.path.join(cur_dir, 'train.tfrecords' if train else 'test.tfrecords')  # 暂时先这样
     86 
     87     with tf.name_scope('input'):
     88         image, label = read_and_decode(filename, num_epochs)
     89         # print(image) 检查点
     90         # tf.train.shuffle_batch应该是从tf.train.string_input_producer生成的文件队列中先打乱再从中抽取组成batch,所以
     91         # 这个打乱后的队列容量和min_after_dequeue(应该是决定原有队列被抽取后的最小样本含量,决定被抽取后再填入的量)
     92         # 根据batch_size的不同会影响训练精度(因为再填充并打乱后很多之前网络没见过的样本会被送入,当所有训练数据都过一遍后,精度会提高),这是我的个人猜测
     93         images, sparse_labels = tf.train.shuffle_batch([image, label], batch_size=batch_size,
     94                                                         num_threads=2, capacity=3000,  # 线程数一般与处理器核数一样
     95                                                        # 但并不是线程越多越快,甚至更多的线程反而会使效率下降
     96                                                        # 参考:https://blog.csdn.net/lujiandong1/article/details/53376802
     97                                                        # https://blog.csdn.net/heiheiya/article/details/80967301
     98                                                        min_after_dequeue=2000)
     99         # print(images) 检查点
    100         return images, sparse_labels
    101     # 注意返回值的类型要与tf.placeholder()中的dtypes, shape都要相同!
    102 
    103 
    104 if __name__ == '__main__':
    105     create_record()

     虽然程序成功运行了,但训练精度很低,还有很多方面需要调整

    除了代码中提到的博文,还参考了下面的:

    https://blog.csdn.net/dcrmg/article/details/79780331

    https://blog.csdn.net/qq_30666517/article/details/79715045

    https://www.cnblogs.com/upright/p/6136265.html

    https://blog.csdn.net/tengxing007/article/details/56847828

    https://blog.csdn.net/ali197294332/article/details/78720309

    https://blog.csdn.net/ying86615791/article/details/73864381

  • 相关阅读:
    SQL优化总结之一
    web前端扩展性知识点
    canvas
    开动大脑js小案例(有空就更新的那种)
    本博客在手,jQuery无敌
    小程序整理(持续更新)
    样式初始化代码
    ajax中的async
    跨域问题解决
    ES6学习笔记(持续更新中)
  • 原文地址:https://www.cnblogs.com/tan-wm/p/9557176.html
Copyright © 2011-2022 走看看