zoukankan      html  css  js  c++  java
  • TensorFlow(十八):从零开始训练图片分类模型

    (一):进入GitHub下载TensorFlow官方模型库tensorflow/models,下载地址:https://github.com/tensorflow/models

    因为我们需要slim模块,所以将包中的slim文件夹复制出来使用。

    (1):在slim中新建images文件夹存放图片集,每个类别的图片放在images下以类别名命名的子文件夹中(转换脚本会把子文件夹名当作类别名)

    (2):新建model文件夹用来放模型

    (3):在datasets文件夹中新建myimages.py文件

    # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    # http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    # ==============================================================================
    """Provides data for the custom `myimages` image-classification dataset.

    Examples are read from TFRecord shards named image_<split>_*.tfrecord in
    the dataset directory (see `_FILE_PATTERN`); the known splits are 'train'
    and 'test'. Adapted from the slim flowers dataset module.
    """
    
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import os
    import tensorflow as tf
    
    from datasets import dataset_utils
    
    slim = tf.contrib.slim

    # Glob for the TFRecord shards; the '%s' is replaced by the split name.
    _FILE_PATTERN = 'image_%s_*.tfrecord'

    # Number of examples per split, reported as the Dataset's num_samples.
    # Adjust these to match your own training set.
    SPLITS_TO_SIZES = dict(train=3500, test=500)

    # Number of label classes.
    _NUM_CLASSES = 5

    # Human-readable descriptions of the items produced by the decoder.
    _ITEMS_TO_DESCRIPTIONS = dict(
        image='A color image of varying size.',
        label='A single integer between 0 and 4',
    )
    
    
    def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
      """Describe one split of the custom image dataset as a slim `Dataset`.

      Args:
        split_name: Which split to read; must be a key of `SPLITS_TO_SIZES`
          ('train' or 'test').
        dataset_dir: Directory containing the TFRecord shards (and, optionally,
          the labels file looked up through `dataset_utils`).
        file_pattern: Optional shard pattern relative to `dataset_dir`. It must
          contain one '%s' placeholder, which is filled with `split_name`.
          Any falsy value falls back to `_FILE_PATTERN`.
        reader: Optional TensorFlow reader class; `None` means
          `tf.TFRecordReader`.

      Returns:
        A `slim.dataset.Dataset` whose `num_samples` is taken from
        `SPLITS_TO_SIZES[split_name]`.

      Raises:
        ValueError: if `split_name` is not a known split.
      """
      if split_name not in SPLITS_TO_SIZES:
        raise ValueError('split name %s was not recognized.' % split_name)

      # A falsy pattern (None or '') means "use the default shard pattern".
      data_sources = os.path.join(
          dataset_dir, (file_pattern or _FILE_PATTERN) % split_name)

      # None is accepted so that dataset_factory can rely on the default reader.
      reader_class = reader if reader is not None else tf.TFRecordReader

      decoder = slim.tfexample_decoder.TFExampleDecoder(
          {
              'image/encoded': tf.FixedLenFeature(
                  (), tf.string, default_value=''),
              'image/format': tf.FixedLenFeature(
                  (), tf.string, default_value='png'),
              'image/class/label': tf.FixedLenFeature(
                  [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
          },
          {
              'image': slim.tfexample_decoder.Image(),
              'label': slim.tfexample_decoder.Tensor('image/class/label'),
          })

      labels_to_names = (dataset_utils.read_label_file(dataset_dir)
                         if dataset_utils.has_labels(dataset_dir) else None)

      return slim.dataset.Dataset(
          data_sources=data_sources,
          reader=reader_class,
          decoder=decoder,
          num_samples=SPLITS_TO_SIZES[split_name],
          items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
          num_classes=_NUM_CLASSES,
          labels_to_names=labels_to_names)
    myimages.py

    (4):修改dataset_factory.py

    from datasets import myimages
    
    # Registry of dataset names (the --dataset_name flag) to their modules.
    datasets_map = dict(
        cifar10=cifar10,
        flowers=flowers,
        imagenet=imagenet,
        mnist=mnist,
        myimages=myimages,  # newly added entry for the custom dataset
    )
    添加的内容

    (二):对图片进行处理,生成tfrecord格式的文件。

    import tensorflow as tf
    import os
    import random
    import math
    import sys
    
    
    # Number of shuffled images held out for the test split.
    _NUM_TEST = 500
    # Seed for the train/test shuffle.
    _RANDOM_SEED = 0
    # Number of TFRecord shards written per split.
    _NUM_SHARDS = 5
    # Root of the image set: one sub-folder per class. The TFRecord shards and
    # the labels file are written into this directory as well.
    DATASET_DIR = "C:/Users/FELIX/Desktop/tensor_study/slim/images/"
    # Labels file name, relative to the dataset directory. Kept as a bare name:
    # write_label_file joins it onto dataset_dir, so a pre-joined absolute path
    # would be doubled on POSIX or when DATASET_DIR is relative. Presumably the
    # same default name dataset_utils.read_label_file looks for -- confirm.
    LABELS_FILENAME = 'labels.txt'
    
    def _get_dataset_filename(dataset_dir, split_name, shard_id):
        """Return the path of one TFRecord shard inside ``dataset_dir``.

        Names look like image_<split>_<shard>-of-<_NUM_SHARDS>.tfrecord, with
        both numbers zero-padded to five digits.
        """
        return os.path.join(
            dataset_dir,
            'image_%s_%05d-of-%05d.tfrecord' % (split_name, shard_id, _NUM_SHARDS))
    
    def _dataset_exists(dataset_dir):
        """Return True only if every shard of both splits already exists.

        Checks _NUM_SHARDS shards for each of the 'train' and 'test' splits,
        using the naming scheme of _get_dataset_filename.
        """
        for split_name in ['train', 'test']:
            for shard_id in range(_NUM_SHARDS):
                output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
                # Must be checked for every shard; testing only the last one would
                # treat a partially written dataset as complete.
                if not tf.gfile.Exists(output_filename):
                    return False
        return True
    
    def _get_filenames_and_classes(dataset_dir):
        """List every image file and every class name under ``dataset_dir``.

        Each immediate sub-directory of ``dataset_dir`` is one class, named after
        the directory. Every regular file inside it is treated as an image of
        that class. Plain files directly in ``dataset_dir`` (such as the labels
        file or TFRecord shards) and nested sub-folders are ignored.

        Both listings are sorted because os.listdir order is arbitrary. Sorting
        keeps the class-id assignment and the order of the returned file list,
        and therefore the seeded train/test shuffle, the same on every machine.

        Returns:
            (photo_filenames, class_names): full paths of all image files, and
            the sorted list of class names.
        """
        class_names = sorted(
            entry for entry in os.listdir(dataset_dir)
            if os.path.isdir(os.path.join(dataset_dir, entry)))

        photo_filenames = []
        for class_name in class_names:
            class_dir = os.path.join(dataset_dir, class_name)
            for filename in sorted(os.listdir(class_dir)):
                path = os.path.join(class_dir, filename)
                # Only regular files can be read as images.
                if os.path.isfile(path):
                    photo_filenames.append(path)

        return photo_filenames, class_names
    
    def int64_feature(values):
        """Wrap an int, or a list/tuple of ints, in a tf.train.Feature."""
        if isinstance(values, (tuple, list)):
            value_list = values
        else:
            value_list = [values]
        int_list = tf.train.Int64List(value=value_list)
        return tf.train.Feature(int64_list=int_list)
    
    def bytes_feature(values):
        """Wrap a single bytes value in a tf.train.Feature (one-element list)."""
        byte_list = tf.train.BytesList(value=[values])
        return tf.train.Feature(bytes_list=byte_list)
    
    def image_to_tfexample(image_data, image_format, class_id):
        """Pack one encoded image and its label into a tf.train.Example.

        Args:
            image_data: Encoded image bytes.
            image_format: Format tag as bytes (the caller passes b'jpg').
            class_id: Integer class label.
        """
        feature_map = {
            'image/encoded': bytes_feature(image_data),
            'image/format': bytes_feature(image_format),
            'image/class/label': int64_feature(class_id),
        }
        return tf.train.Example(features=tf.train.Features(feature=feature_map))
    
    def write_label_file(labels_to_class_names, dataset_dir, filename=LABELS_FILENAME):
        """Write the label-id -> class-name mapping, one 'id:name' line each.

        Args:
            labels_to_class_names: dict mapping integer label ids to class names.
            dataset_dir: Directory the file is written into.
            filename: File name joined onto dataset_dir. If it is an absolute
                path, os.path.join discards dataset_dir and uses it unchanged.
        """
        labels_filename = os.path.join(dataset_dir, filename)
        with tf.gfile.Open(labels_filename, 'w') as f:
            for label, class_name in labels_to_class_names.items():
                # Newline escape keeps one mapping per line.
                f.write('%d:%s\n' % (label, class_name))
    
    def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
        """Write the images of one split as _NUM_SHARDS TFRecord shards.

        Args:
            split_name: 'train' or 'test'; used in the shard file names.
            filenames: Image paths for this split. Each image's class is the
                name of its parent directory.
            class_names_to_ids: dict mapping class (directory) names to ids.
            dataset_dir: Directory the shards are written into.

        Images that cannot be read are reported on stdout and skipped.
        """
        assert split_name in ['train', 'test']
        # Round up so the last shard absorbs the remainder; flooring silently
        # dropped len(filenames) % _NUM_SHARDS images from the split.
        num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))
        # No Graph/Session is needed: nothing here runs TF ops, and creating a
        # Session can needlessly reserve device (GPU) memory.
        for shard_id in range(_NUM_SHARDS):
            output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
            with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                start_ndx = shard_id * num_per_shard
                end_ndx = min((shard_id + 1) * num_per_shard, len(filenames))
                for i in range(start_ndx, end_ndx):
                    sys.stdout.write('\r>> Converting image %d/%d shard %d'
                                     % (i + 1, len(filenames), shard_id))
                    sys.stdout.flush()
                    try:
                        # Binary mode is required: the raw encoded bytes are stored.
                        with tf.gfile.FastGFile(filenames[i], 'rb') as image_file:
                            image_data = image_file.read()
                    except (IOError, tf.errors.OpError) as e:
                        # tf.gfile reports failures such as a missing file as
                        # tf.errors.OpError subclasses, not IOError.
                        print("Could not read:", filenames[i])
                        print("Error:", e)
                        print("Skip it\n")
                        continue
                    class_name = os.path.basename(os.path.dirname(filenames[i]))
                    class_id = class_names_to_ids[class_name]
                    # Every image is tagged 'jpg'; assumes the set is JPEG --
                    # TODO confirm if PNGs are included.
                    example = image_to_tfexample(image_data, b'jpg', class_id)
                    tfrecord_writer.write(example.SerializeToString())
        sys.stdout.write('\n')
        sys.stdout.flush()
    
    
    if __name__ == '__main__':
        # Skip the conversion entirely when every TFRecord shard already exists.
        if _dataset_exists(DATASET_DIR):
            print('tfrecord文件已存在')
        else:
            # Collect all image paths and the class names (one per sub-folder).
            photo_filenames, class_names = _get_filenames_and_classes(DATASET_DIR)
            # Map class names to ids, e.g. {'animal': 0, 'flower': 1, ...}.
            class_names_to_ids = {name: idx for idx, name in enumerate(class_names)}

            # The first _NUM_TEST shuffled images become the test split, so at
            # least one image must remain for training.
            if len(photo_filenames) <= _NUM_TEST:
                raise ValueError(
                    'Found %d images in %s, but more than %d are required '
                    '(_NUM_TEST images are held out for the test split).'
                    % (len(photo_filenames), DATASET_DIR, _NUM_TEST))

            # Split into training and test sets with a fixed seed.
            random.seed(_RANDOM_SEED)
            random.shuffle(photo_filenames)
            training_filenames = photo_filenames[_NUM_TEST:]
            testing_filenames = photo_filenames[:_NUM_TEST]

            _convert_dataset('train', training_filenames, class_names_to_ids, DATASET_DIR)
            _convert_dataset('test', testing_filenames, class_names_to_ids, DATASET_DIR)

            # Write the id -> class-name mapping next to the shards.
            labels_to_class_names = dict(enumerate(class_names))
            write_label_file(labels_to_class_names, DATASET_DIR)
    生成tfrecord

     (三):新建批处理文件,开始训练模型

    python C:/Users/FELIX/Desktop/tensor_study/slim/train_image_classifier.py ^
    --train_dir=C:/Users/FELIX/Desktop/tensor_study/slim/model ^
    --dataset_name=myimages ^
    --dataset_split_name=train ^
    --dataset_dir=C:/Users/FELIX/Desktop/tensor_study/slim/images ^
    --batch_size=10 ^
    --max_number_of_steps=10000 ^
    --model_name=inception_v3
    REM The last argument line has no trailing ^, so "pause" runs as its own command.
    pause
    
    
    
    注释:
    第一行表示运行训练文件,路径为全路径
    第二行表示模型存放位置
    第三行为数据集名称myimages,即在dataset_factory.py的datasets_map中注册的名称(对应datasets/myimages.py)
    第四行为使用的训练集
    第五行为数据集所在的位置
    第六行为批次大小,默认为32,看个人GPU,我用10
    第七行为最大训练步数(max_number_of_steps),默认不限制步数
    第八行为使用模型名称
    批处理文件
  • 相关阅读:
    Flask&&人工智能AI -- 12
    Flask&&人工智能AI -- 11
    Flask&&人工智能AI -- 10
    Flask&&人工智能AI -- 9
    Flask&&人工智能AI -- 8
    Flask&&人工智能AI -- 8 HTML5+ 初识,HBuilder,夜神模拟器,Webview
    Flask&&人工智能AI -- 7 MongoDB
    Flask&&人工智能AI -- 6 人工智能初识,百度AI,图灵机器人
    Flask&&人工智能AI --5 Flask-session、WTForms、数据库连接池、Websocket
    [转]八款开源Android游戏引擎
  • 原文地址:https://www.cnblogs.com/felixwang2/p/9241965.html
Copyright © 2011-2022 走看看