  • Exploring data reading in TensorFlow (2)

    Reading TFRecord data in TensorFlow

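    In TensorFlow, handling TFRecord files and training a model follows this architecture:
      1. Get the file list and create a file queue: http://blog.csdn.net/lovelyaiq/article/details/78711944 (TFRecord format, saving, reading)
      2. Image preprocessing: http://blog.csdn.net/lovelyaiq/article/details/78716325
      3. Assembling batches: http://blog.csdn.net/lovelyaiq/article/details/78727189
      4. Designing the loss function and the gradient descent algorithm: http://blog.csdn.net/lovelyaiq/article/details/78616736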

    - First, understand the TFRecord format. TensorFlow provides TFRecord as a unified format for storing data.

    // tf.train.Example
    message Example {
        Features features = 1;
    }
    message Features {
        map<string, Feature> feature = 1;
    }
    message Feature {
        oneof kind {
            BytesList bytes_list = 1;
            FloatList float_list = 2;
            Int64List int64_list = 3;
        }
    }

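    As the definition shows, tf.train.Example stores data as a dictionary: the string is the key, and each value is one of three list types: bytes, float, or int64. The following example shows how to save and read files with TFRecord. Saving and reading use tf.python_io.TFRecordWriter and tf.TFRecordReader(), respectively.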
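    On disk, each record written by the TFRecord writer is framed as an 8-byte little-endian length, a masked CRC-32C of those length bytes, the payload (for example a serialized tf.train.Example), and a masked CRC-32C of the payload. The pure-Python sketch below writes and reads that framing without TensorFlow to make the file layout concrete. The function names are hypothetical, and this is meant for understanding the format, not as a replacement for tf.python_io.TFRecordWriter or tf.TFRecordReader.

```python
import io
import struct


def _make_crc32c_table():
    # Lookup table for CRC-32C (Castagnoli), reflected polynomial 0x82F63B78.
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_CRC32C_TABLE = _make_crc32c_table()


def crc32c(data):
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _CRC32C_TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data):
    # TFRecord stores a "masked" CRC: rotate right by 15 bits, then add a constant.
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_record(f, data):
    header = struct.pack('<Q', len(data))           # uint64 length, little-endian
    f.write(header)
    f.write(struct.pack('<I', masked_crc(header)))  # uint32 masked CRC of the length
    f.write(data)                                   # payload bytes
    f.write(struct.pack('<I', masked_crc(data)))    # uint32 masked CRC of the payload


def read_records(f):
    while True:
        header = f.read(8)
        if not header:
            return  # clean end of file
        (length,) = struct.unpack('<Q', header)
        (length_crc,) = struct.unpack('<I', f.read(4))
        if length_crc != masked_crc(header):
            raise ValueError('corrupted record length')
        data = f.read(length)
        (data_crc,) = struct.unpack('<I', f.read(4))
        if data_crc != masked_crc(data):
            raise ValueError('corrupted record payload')
        yield data


if __name__ == '__main__':
    buf = io.BytesIO()
    for payload in [b'first record', b'second record']:
        write_record(buf, payload)
    buf.seek(0)
    print(list(read_records(buf)))
```

    Each record costs 16 bytes of framing on top of its payload, and the two checksums let a reader detect a truncated or corrupted file instead of silently parsing garbage.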

    Next, convert the raw data to TFRecord (adapted from the DeepLab code in tensorflow/models):

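    import collections.abc
    import math
    import os
    import sys
    
    import six
    import tensorflow as tf
    
    # FLAGS, _NUM_SHARDS, _IMAGE_FORMAT_MAP, _IMAGE_FILENAME_RE, _get_files and the
    # build_data module are defined elsewhere in the DeepLab dataset scripts (not shown).
    
    def _int64_list_feature(values):
      """Returns a TF-Feature of int64_list.
    
      Args:
        values: A scalar or list of values.
    
      Returns:
        A TF-Feature.
      """
      if not isinstance(values, collections.abc.Iterable):
        values = [values]
    
      return tf.train.Feature(int64_list=tf.train.Int64List(value=values))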
    
    
    def _bytes_list_feature(values):
      """Returns a TF-Feature of bytes.
    
      Args:
        values: A string.
    
      Returns:
        A TF-Feature.
      """
      def norm2bytes(value):
        return value.encode() if isinstance(value, str) and six.PY3 else value
    
      return tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[norm2bytes(values)]))
    
    
    def image_seg_to_tfexample(image_data, filename, height, width, seg_data):
      """Converts one image/segmentation pair to tf example.
    
      Args:
        image_data: string of image data.
        filename: image filename.
        height: image height.
        width: image width.
        seg_data: string of semantic segmentation data.
    
      Returns:
        tf example of one image/segmentation pair.
      """
      return tf.train.Example(features=tf.train.Features(feature={
          'image/encoded': _bytes_list_feature(image_data),
          'image/filename': _bytes_list_feature(filename),
          'image/format': _bytes_list_feature(
              _IMAGE_FORMAT_MAP[FLAGS.image_format]),
          'image/height': _int64_list_feature(height),
          'image/width': _int64_list_feature(width),
          'image/channels': _int64_list_feature(3),
          'image/segmentation/class/encoded': (
              _bytes_list_feature(seg_data)),
          'image/segmentation/class/format': _bytes_list_feature(
              FLAGS.label_format),
      }))
    
    def _convert_dataset(dataset_split):
      """Converts the specified dataset split to TFRecord format.
    
      Args:
        dataset_split: The dataset split (e.g., train, val).
    
      Raises:
        RuntimeError: If loaded image and label have different shape, or if the
          image file with specified postfix could not be found.
      """
      image_files = _get_files('image', dataset_split)  # get the file lists
      label_files = _get_files('label', dataset_split)
    
      num_images = len(image_files)
      num_per_shard = int(math.ceil(num_images / float(_NUM_SHARDS)))
    
      image_reader = build_data.ImageReader('png', channels=3)
      label_reader = build_data.ImageReader('png', channels=1)
    
      for shard_id in range(_NUM_SHARDS):  # write _NUM_SHARDS TFRecord files
        shard_filename = '%s-%05d-of-%05d.tfrecord' % (
            dataset_split, shard_id, _NUM_SHARDS)
        output_filename = os.path.join(FLAGS.output_dir, shard_filename)
        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
          start_idx = shard_id * num_per_shard
          end_idx = min((shard_id + 1) * num_per_shard, num_images)
          for i in range(start_idx, end_idx):  # read and convert the files one by one
            sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                i + 1, num_images, shard_id))
            sys.stdout.flush()
            # Read the image.
            image_data = tf.gfile.FastGFile(image_files[i], 'rb').read()
            height, width = image_reader.read_image_dims(image_data)
            # Read the semantic segmentation annotation.
            seg_data = tf.gfile.FastGFile(label_files[i], 'rb').read()
            seg_height, seg_width = label_reader.read_image_dims(seg_data)
            if height != seg_height or width != seg_width:
              raise RuntimeError('Shape mismatched between image and label.')
            # Convert to tf example.
            re_match = _IMAGE_FILENAME_RE.search(image_files[i])
            if re_match is None:
              raise RuntimeError('Invalid image filename: ' + image_files[i])
            filename = os.path.basename(re_match.group(1))
            example = build_data.image_seg_to_tfexample(
                image_data, filename, height, width, seg_data)
            tfrecord_writer.write(example.SerializeToString())
        sys.stdout.write('\n')
        sys.stdout.flush()
    
    
    def main(unused_argv):
      # Only support converting 'train' and 'val' sets for now.
      for dataset_split in ['train', 'val']:
        _convert_dataset(dataset_split)
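    The shard bookkeeping in _convert_dataset is plain integer arithmetic. The sketch below reproduces it without reading any images, so you can check which image indices land in which file. shard_plan is a hypothetical helper, and the naming scheme is copied from the code above.

```python
import math


def shard_plan(split, num_items, num_shards):
    """Returns (filename, start_idx, end_idx) for each shard, as in _convert_dataset."""
    num_per_shard = int(math.ceil(num_items / float(num_shards)))
    plan = []
    for shard_id in range(num_shards):
        filename = '%s-%05d-of-%05d.tfrecord' % (split, shard_id, num_shards)
        start_idx = shard_id * num_per_shard
        end_idx = min((shard_id + 1) * num_per_shard, num_items)
        plan.append((filename, start_idx, end_idx))
    return plan


if __name__ == '__main__':
    for name, start, end in shard_plan('train', 10, 4):
        print(name, start, end)
```

    With 10 images and 4 shards, each shard holds ceil(10/4) = 3 images and the last one holds the single remaining image. If there are fewer images than shards, the trailing shards get start_idx >= end_idx, so the original loop writes empty files for them.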

    Then read the data during training. There are three approaches: raw queue-based reading, tf.data.TFRecordDataset, and slim. For reference:

    - The DeepLab implementation (slim):

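    import os
    
    import tensorflow as tf
    slim = tf.contrib.slim
    dataset = slim.dataset
    tfexample_decoder = slim.tfexample_decoder
    
    # _DATASETS_INFORMATION, _FILE_PATTERN and _ITEMS_TO_DESCRIPTIONS are defined
    # elsewhere in DeepLab's dataset module (not shown here).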
    
    def get_dataset(dataset_name, split_name, dataset_dir):
      """Gets an instance of slim Dataset.
    
      Args:
        dataset_name: Dataset name.
        split_name: A train/val Split name.
        dataset_dir: The directory of the dataset sources.
    
      Returns:
        An instance of slim Dataset.
    
      Raises:
        ValueError: if the dataset_name or split_name is not recognized.
      """
      if dataset_name not in _DATASETS_INFORMATION:
        raise ValueError('The specified dataset is not supported yet.')
    
      splits_to_sizes = _DATASETS_INFORMATION[dataset_name].splits_to_sizes
    
      if split_name not in splits_to_sizes:
        raise ValueError('data split name %s not recognized' % split_name)
    
      # Prepare the variables for different datasets.
      num_classes = _DATASETS_INFORMATION[dataset_name].num_classes
      ignore_label = _DATASETS_INFORMATION[dataset_name].ignore_label
    
      file_pattern = _FILE_PATTERN
      file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
    
      # Specify how the TF-Examples are decoded.
      keys_to_features = {
          'image/encoded': tf.FixedLenFeature(
              (), tf.string, default_value=''),
          'image/filename': tf.FixedLenFeature(
              (), tf.string, default_value=''),
          'image/format': tf.FixedLenFeature(
              (), tf.string, default_value='jpeg'),
          'image/height': tf.FixedLenFeature(
              (), tf.int64, default_value=0),
          'image/width': tf.FixedLenFeature(
              (), tf.int64, default_value=0),
          'image/segmentation/class/encoded': tf.FixedLenFeature(
              (), tf.string, default_value=''),
          'image/segmentation/class/format': tf.FixedLenFeature(
              (), tf.string, default_value='png'),
      }
      items_to_handlers = {
          'image': tfexample_decoder.Image(
              image_key='image/encoded',
              format_key='image/format',
              channels=3),
          'image_name': tfexample_decoder.Tensor('image/filename'),
          'height': tfexample_decoder.Tensor('image/height'),
          'width': tfexample_decoder.Tensor('image/width'),
          'labels_class': tfexample_decoder.Image(
              image_key='image/segmentation/class/encoded',
              format_key='image/segmentation/class/format',
              channels=1),
      }
    
      decoder = tfexample_decoder.TFExampleDecoder(
          keys_to_features, items_to_handlers)
    
      return dataset.Dataset(
          data_sources=file_pattern,
          reader=tf.TFRecordReader,
          decoder=decoder,
          num_samples=splits_to_sizes[split_name],
          items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
          ignore_label=ignore_label,
          num_classes=num_classes,
          name=dataset_name,
          multi_label=True)
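    get_dataset only finds data if file_pattern % split_name matches the shard names written by _convert_dataset. _FILE_PATTERN is not shown in this post; assuming a DeepLab-style value of '%s-*', the check below uses fnmatch as a stand-in for the glob expansion slim applies to data_sources. FILE_PATTERN and match_split are hypothetical names.

```python
import fnmatch

# Assumption: a DeepLab-style pattern where '%s' is replaced by the split name.
FILE_PATTERN = '%s-*'


def match_split(filenames, split_name, file_pattern=FILE_PATTERN):
    """Returns, in sorted order, the shard basenames selected for one split."""
    pattern = file_pattern % split_name
    return sorted(fnmatch.filter(filenames, pattern))


if __name__ == '__main__':
    shards = ['train-00000-of-00002.tfrecord', 'train-00001-of-00002.tfrecord',
              'val-00000-of-00001.tfrecord']
    print(match_split(shards, 'train'))
    print(match_split(shards, 'val'))
```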

    - Then process it as follows:

      ......
      data_provider = dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=num_readers,
          num_epochs=None if is_training else 1,
          shuffle=is_training)
      image, label, image_name, height, width = _get_data(data_provider,
                                                          dataset_split)
      if label is not None:
        if label.shape.ndims == 2:
          label = tf.expand_dims(label, 2)
        elif label.shape.ndims == 3 and label.shape.dims[2] == 1:
          pass
        else:
          raise ValueError('Input label shape must be [height, width], or '
                           '[height, width, 1].')
    
        label.set_shape([None, None, 1])
      original_image, image, label = input_preprocess.preprocess_image_and_label(
          image,
          label,
          crop_height=crop_size[0],
          crop_width=crop_size[1],
          min_resize_value=min_resize_value,
          max_resize_value=max_resize_value,
          resize_factor=resize_factor,
          min_scale_factor=min_scale_factor,
          max_scale_factor=max_scale_factor,
          scale_factor_step_size=scale_factor_step_size,
          ignore_label=dataset.ignore_label,
          is_training=is_training,
          model_variant=model_variant)
      sample = {
          common.IMAGE: image,
          common.IMAGE_NAME: image_name,
          common.HEIGHT: height,
          common.WIDTH: width
      }
      if label is not None:
        sample[common.LABEL] = label
    
      if not is_training:
        # Original image is only used during visualization.
        sample[common.ORIGINAL_IMAGE] = original_image
        num_threads = 1
    
      return tf.train.batch(
          sample,
          batch_size=batch_size,
          num_threads=num_threads,
          capacity=32 * batch_size,
          allow_smaller_final_batch=not is_training,
          dynamic_pad=True)
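    Two steps in this function are easy to misread: the label-shape normalization, and dynamic_pad=True in tf.train.batch, which pads variable-sized tensors in a batch up to the largest shape (with zeros for numeric tensors). The NumPy sketch below mirrors both under those assumptions; normalize_label and pad_batch are hypothetical helpers, not TensorFlow functions.

```python
import numpy as np


def normalize_label(label):
    """Accepts [H, W] or [H, W, 1] and returns [H, W, 1], like the check above."""
    if label.ndim == 2:
        return label[:, :, np.newaxis]  # same effect as tf.expand_dims(label, 2)
    if label.ndim == 3 and label.shape[2] == 1:
        return label
    raise ValueError('Input label shape must be [height, width], or '
                     '[height, width, 1].')


def pad_batch(arrays, pad_value=0):
    """Pads every array up to the largest shape in the batch, then stacks them.

    A NumPy illustration of dynamic padding, not TensorFlow's implementation.
    """
    max_shape = np.max([a.shape for a in arrays], axis=0)
    padded = []
    for a in arrays:
        pad_width = [(0, int(m - s)) for s, m in zip(a.shape, max_shape)]
        padded.append(np.pad(a, pad_width, mode='constant',
                             constant_values=pad_value))
    return np.stack(padded)


if __name__ == '__main__':
    a = normalize_label(np.ones((2, 3), dtype=np.uint8))
    b = normalize_label(np.ones((4, 1, 1), dtype=np.uint8))
    print(pad_batch([a, b]).shape)
```

    Padding is added at the end of each dimension, so the original pixels keep their coordinates inside the padded tensor.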

    - The main focus here is TensorFlow's TFRecord reading methods and reading data with slim:

    def read_data(is_training, split_name):
      file_pattern = '{}_{}.tfrecord'.format(args.data_name, split_name)
      tfrecord_path = os.path.join(args.data_dir,'records',file_pattern)
    
      if is_training:
        dataset = get_dataset(tfrecord_path)  # read the TFRecord via slim
        image, gt_mask = extract_batch(dataset, args.batch_size, is_training)
      else:
        image, gt_mask = read_tfrecord(tfrecord_path)  # read the TFRecord with the raw reader
        image, gt_mask = preprocess.preprocess_image(image, gt_mask, is_training)
      return image, gt_mask
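    The tf.data.TFRecordDataset approach listed among the three options is not shown in the post. A minimal sketch follows, assuming TensorFlow 2.x with eager execution (tf.io.parse_single_example, tf.io.FixedLenFeature and tf.io.TFRecordWriter are the TF 2 names for tf.parse_single_example, tf.FixedLenFeature and tf.python_io.TFRecordWriter). The feature spec is a subset of the DeepLab keys, and the helper names are hypothetical.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical feature spec: a subset of the DeepLab keys used in this post.
KEYS_TO_FEATURES = {
    'image/encoded': tf.io.FixedLenFeature((), tf.string, default_value=''),
    'image/height': tf.io.FixedLenFeature((), tf.int64, default_value=0),
    'image/width': tf.io.FixedLenFeature((), tf.int64, default_value=0),
}


def make_example(image_bytes, height, width):
    """Builds a tf.train.Example holding the three features above."""
    return tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[image_bytes])),
        'image/height': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[height])),
        'image/width': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[width])),
    }))


def write_records(path, examples):
    with tf.io.TFRecordWriter(path) as writer:
        for example in examples:
            writer.write(example.SerializeToString())


def parse_fn(serialized):
    return tf.io.parse_single_example(serialized, KEYS_TO_FEATURES)


def make_dataset(filenames, batch_size, is_training):
    dataset = tf.data.TFRecordDataset(filenames)
    if is_training:
        # Shuffle records and repeat indefinitely, like num_epochs=None above.
        dataset = dataset.shuffle(buffer_size=1000).repeat()
    # Image decoding and preprocessing would also go into this map() call.
    dataset = dataset.map(parse_fn)
    return dataset.batch(batch_size)


if __name__ == '__main__':
    path = os.path.join(tempfile.mkdtemp(), 'demo.tfrecord')
    write_records(path, [make_example(b'abc', 2, 3), make_example(b'de', 4, 5)])
    for batch in make_dataset([path], batch_size=2, is_training=False):
        print(batch['image/height'].numpy(), batch['image/width'].numpy())
```

    Compared with the queue-based pipeline, there are no queue runners or coordinators to start; the dataset object is iterated directly.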

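    1. Data processing workflow
    Input data processing follows roughly the same workflow in most projects, summarized as follows:
        Convert the data into multiple TFRecord files
        Create the file list with tf.train.match_filenames_once()
        Create the input file queue with tf.train.string_input_producer(), which can shuffle the order of the input files
        Read the data in the files with tf.TFRecordReader()
        Parse the data with tf.parse_single_example()
        Decode and preprocess the data
        Combine the data into batches with tf.train.shuffle_batch()
        Use the batches for training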
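    The shuffling in tf.train.shuffle_batch comes from a buffer: a batch is only dequeued once at least min_after_dequeue + batch_size examples are queued, and elements are drawn from that buffer at random. The pure-Python generator below is a simplified, single-threaded model of that behavior, not TensorFlow's implementation. It emits the leftovers as a smaller final batch, as with allow_smaller_final_batch=True; the function name is hypothetical.

```python
import random


def shuffle_batch(examples, batch_size, min_after_dequeue, seed=None):
    """Simplified model of a shuffling queue feeding fixed-size batches."""
    rng = random.Random(seed)
    buffer = []

    def draw(n):
        batch = []
        for _ in range(n):
            # Swap a random element to the end, then pop it (O(1) removal).
            i = rng.randrange(len(buffer))
            buffer[i], buffer[-1] = buffer[-1], buffer[i]
            batch.append(buffer.pop())
        return batch

    for example in examples:
        buffer.append(example)
        if len(buffer) >= min_after_dequeue + batch_size:
            yield draw(batch_size)
    # Input exhausted ("queue closed"): drain what is left.
    while buffer:
        yield draw(min(batch_size, len(buffer)))


if __name__ == '__main__':
    for batch in shuffle_batch(range(10), batch_size=3, min_after_dequeue=4, seed=0):
        print(batch)
```

    The first batch can only contain elements from the first min_after_dequeue + batch_size examples, so a larger min_after_dequeue mixes the data better at the cost of more memory and a slower start.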

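    2. Input data processing framework
    The framework covers three main aspects:
        The TFRecord input data format
        Image data processing
        Multi-threaded input data processing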
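    For the third aspect, TensorFlow 1.x runs the readers in background threads (queue runners) that fill a bounded queue while the training loop dequeues batches; parameters such as num_readers and num_threads in the DeepLab code above configure this parallelism. The sketch below shows the same producer/consumer pattern using only Python's queue and threading modules. threaded_reader and read_shard are hypothetical names, and this is not how TensorFlow implements it internally.

```python
import queue
import threading

_DONE = object()  # sentinel each worker puts when it has no more shards


def threaded_reader(shards, read_shard, num_threads=2, capacity=8):
    """Yields items produced by num_threads reader threads via a bounded queue."""
    shard_queue = queue.Queue()
    for shard in shards:
        shard_queue.put(shard)
    out_queue = queue.Queue(maxsize=capacity)

    def worker():
        try:
            while True:
                try:
                    shard = shard_queue.get_nowait()
                except queue.Empty:
                    return
                for item in read_shard(shard):
                    out_queue.put(item)  # blocks while the queue is full
        finally:
            out_queue.put(_DONE)

    threads = [threading.Thread(target=worker, daemon=True)
               for _ in range(num_threads)]
    for t in threads:
        t.start()
    finished = 0
    while finished < num_threads:
        item = out_queue.get()
        if item is _DONE:
            finished += 1
        else:
            yield item
    for t in threads:
        t.join()


if __name__ == '__main__':
    data = {'a': [1, 2, 3], 'b': [4, 5], 'c': [6]}
    print(sorted(threaded_reader(list(data), lambda name: data[name])))
```

    The bounded queue provides back-pressure: readers stall when the consumer falls behind, which caps memory use, and the order of items across shards is not deterministic.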


  • Original article: https://www.cnblogs.com/ranjiewen/p/10171149.html