zoukankan      html  css  js  c++  java
  • TensorFlow CNN 測试CIFAR-10数据集


    本系列文章由 @yhl_leo 出品,转载请注明出处。
    文章链接: http://blog.csdn.net/yhl_leo/article/details/50738311


    1 CIFAR-10 数据集

    CIFAR-10数据集是机器学习中的一个通用的用于图像识别的基础数据集。官网链接为:The CIFAR-10 dataset

    cifar10

    下载使用的版本号是:

    version

    将其解压后(代码中包括自己主动解压代码)。内容为:

    cifar10 data

    cifar10 data2

    2 測试代码

    測试代码发布在GitHub:yhlleo

    主要代码及作用:

    文件 作用
    cifar10_input.py 读取本地或者在线下载CIFAR-10的二进制文件格式数据集
    cifar10.py 建立CIFAR-10的模型
    cifar10_train.py 在CPU或GPU上训练CIFAR-10的模型
    cifar10_multi_gpu_train.py 在多个GPU上训练CIFAR-10的模型
    cifar10_eval.py 评估CIFAR-10模型的预測性能

    该部分的代码,介绍了怎样使用TensorFlow在CPU和GPU上训练和评估卷积神经网络(convolutional neural network, CNN)。

    3 相关网页及教程

    更加具体地介绍说明。请浏览网页:Convolutional Neural Networks

    中文站点极客学院也有该部分的汉译版:卷积神经网络

    代码源自tensorflow官网:tensorflow/models/image/cifar10

    4 代码改动说明

    GitHub发布代码相对源代码(本人的Tensorflow版本号还是0.5),主要进行了下面修正:

    • cifar10.py
    # indices = tf.reshape(tf.range(FLAGS.batch_size), [FLAGS.batch_size, 1])
    indices = tf.reshape(range(FLAGS.batch_size), [FLAGS.batch_size, 1])
    
    # or
    indices = tf.reshape(tf.range(0, FLAGS.batch_size, 1), [FLAGS.batch_size, 1])

    此处,源代码编译时会出现下面错误:

      ...
      File ".../cifar10.py", line 271, in loss
        indices = tf.reshape(tf.range(FLAGS.batch_size), [FLAGS.batch_size, 1])
    TypeError: range() takes at least 2 arguments (1 given)
    • cifar10_input_test.py
    #self.a
    
  • 相关阅读:
    2月24日-寒假进度24
    2月23日-寒假学习进度23
    2月22日-寒假学习进度22
    2月21日-寒假学习进度21
    第一周冲刺意见汇总
    团队绩效评估
    团队工作第七天
    团队工作第六天
    团队工作第五天
    团队工作第四天
  • 原文地址:https://www.cnblogs.com/wzzkaifa/p/7245095.html
Copyright © 2011-2022 走看看