  • TensorFlow: converting libsvm-format data to TFRecord (parse libsvm data to TFRecord)

    # Writing libsvm-format data to TFRecord (write libsvm)
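    The input is one example per line: a label, an optional total-feature-count field, and then index:value pairs (the `':' not in l[1]` check in the script below handles the optional count). Hypothetical sample lines:

        # label [num_features] index:value index:value ...
        1 10 0:0.5 3:1.2 7:0.3
        0 10 1:1.0 4:0.8 9:2.5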

       

    #!/usr/bin/env python
    #coding=gbk
    # ==============================================================================
    # file    gen-records.py
    # author  chenghuige
    # date    2016-08-12 11:52:01.952044
    # Description
    # ==============================================================================

    from __future__ import absolute_import
    from __future__ import division
    #from __future__ import print_function

    import sys, os

    import tensorflow as tf
    import numpy as np

    flags = tf.app.flags
    FLAGS = flags.FLAGS

    _float_feature = lambda v: tf.train.Feature(float_list=tf.train.FloatList(value=v))
    _int_feature = lambda v: tf.train.Feature(int64_list=tf.train.Int64List(value=v))

    #how to store global info, using sequence example?
    def main(argv):
      writer = tf.python_io.TFRecordWriter(argv[2])
      for line in open(argv[1]):
        l = line.rstrip().split()
        label = int(l[0])

        # an optional second field gives the total number of features
        start = 1
        num_features = 0
        if ':' not in l[1]:
          num_features = int(l[1])
          start += 1

        indexes = []
        values = []
        for item in l[start:]:
          index, value = item.split(':')
          indexes.append(int(index))
          values.append(float(value))

        example = tf.train.Example(features=tf.train.Features(feature={
            'label': _int_feature([label]),
            'num_features': _int_feature([num_features]),
            'index': _int_feature(indexes),
            'value': _float_feature(values)
            }))
        writer.write(example.SerializeToString())
      writer.close()

    if __name__ == '__main__':
      tf.app.run()
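    To sanity-check the output, the records can be read back eagerly with tf.python_io.tf_record_iterator. A minimal sketch (the file name comes from the command line, as above):

        #!/usr/bin/env python
        # verify-records.py -- dump the first few examples from a TFRecord file
        from __future__ import print_function
        import sys
        import tensorflow as tf

        for i, record in enumerate(tf.python_io.tf_record_iterator(sys.argv[1])):
            example = tf.train.Example.FromString(record)
            f = example.features.feature
            print(f['label'].int64_list.value,
                  f['index'].int64_list.value,
                  f['value'].float_list.value)
            if i >= 2:
                break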

       

       

    # Reading libsvm-format data from TFRecord (read libsvm)

       

    #!/usr/bin/env python
    #coding=gbk
    # ==============================================================================
    # file    read-records.py
    # author  chenghuige
    # date    2016-07-19 17:09:07.466651
    # Description
    # ==============================================================================

    #@TODO treat comment as sparse input ?

    from __future__ import absolute_import
    from __future__ import division
    #from __future__ import print_function

    import sys, os, time
    import tensorflow as tf

    import numpy as np

    flags = tf.app.flags
    FLAGS = flags.FLAGS

    flags.DEFINE_integer('batch_size', 5, 'Batch size.')
    flags.DEFINE_integer('num_epochs', 10, 'Number of epochs to run trainer.')
    flags.DEFINE_integer('num_preprocess_threads', 12, '')

    MIN_AFTER_DEQUEUE = 10000

    def read(filename_queue):
      reader = tf.TFRecordReader()
      _, serialized_example = reader.read(filename_queue)
      return serialized_example

    def decode(batch_serialized_examples):
      features = tf.parse_example(
          batch_serialized_examples,
          features={
              'label' : tf.FixedLenFeature([], tf.int64),
              'index' : tf.VarLenFeature(tf.int64),
              'value' : tf.VarLenFeature(tf.float32),
          })

      label = features['label']
      index = features['index']
      value = features['value']

      return label, index, value

    def batch_inputs(files, batch_size, num_epochs=None, num_preprocess_threads=1):
      """Reads input data num_epochs times."""
      if not num_epochs: num_epochs = None

      with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer(
            files, num_epochs=num_epochs)

        serialized_example = read(filename_queue)
        batch_serialized_examples = tf.train.shuffle_batch(
            [serialized_example],
            batch_size=batch_size,
            num_threads=num_preprocess_threads,
            capacity=MIN_AFTER_DEQUEUE + (num_preprocess_threads + 1) * batch_size,
            # Ensures a minimum amount of shuffling of examples.
            min_after_dequeue=MIN_AFTER_DEQUEUE)

      return decode(batch_serialized_examples)

    def read_records():
      # Tell TensorFlow that the model will be built into the default Graph.
      with tf.Graph().as_default():
        # Input labels and sparse features.
        tf_record_pattern = sys.argv[1]
        data_files = tf.gfile.Glob(tf_record_pattern)
        label, index, value = batch_inputs(
            data_files,
            batch_size=FLAGS.batch_size,
            num_epochs=FLAGS.num_epochs,
            num_preprocess_threads=FLAGS.num_preprocess_threads)

        # The op for initializing the variables.
        init_op = tf.group(tf.initialize_all_variables(),
                           tf.initialize_local_variables())

        # Create a session for running operations in the Graph.
        #sess = tf.Session()
        sess = tf.InteractiveSession()

        # Initialize the variables (the trained variables and the
        # epoch counter).
        sess.run(init_op)

        # Start input enqueue threads.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
          step = 0
          while not coord.should_stop():
            start_time = time.time()
            label_, index_, value_ = sess.run([label, index, value])
            print label_
            print index_
            print value_
            print index_[0]
            print index_[1]
            print index_[2]
            duration = time.time() - start_time
            step += 1
        except tf.errors.OutOfRangeError:
          print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
        finally:
          # When done, ask the threads to stop.
          coord.request_stop()

        # Wait for threads to finish.
        coord.join(threads)
        sess.close()


    def main(_):
      read_records()


    if __name__ == '__main__':
      tf.app.run()
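    The queue-runner pipeline above was the standard input path in early TensorFlow. For reference, the same decode logic can be written with tf.data (a sketch assuming TF 1.4 or later; it is not part of the original code):

        # sketch: tf.data equivalent of the read/shuffle_batch pipeline above
        import tensorflow as tf

        def make_dataset(files, batch_size=5, num_epochs=10):
            features = {
                'label': tf.FixedLenFeature([], tf.int64),
                'index': tf.VarLenFeature(tf.int64),
                'value': tf.VarLenFeature(tf.float32),
            }
            dataset = tf.data.TFRecordDataset(files)
            dataset = dataset.shuffle(10000).repeat(num_epochs).batch(batch_size)
            # parse a whole batch of serialized examples at once, like decode()
            return dataset.map(lambda x: tf.parse_example(x, features))

        # usage: the batch tensors come from an iterator instead of queues
        # label, index, value would be fetched via
        # make_dataset(data_files).make_one_shot_iterator().get_next()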

       

    # Text classification

    https://github.com/chenghuige/tensorflow-example

       

    Using TFRecord requires only small modifications, as shown below; I will update the code on GitHub soon.

       

    class SparseClassificationTrainer(object):
      """General framework for sparse binary classification trainers.

      Sparse binary classification uses the sparse embedding lookup trick,
      see https://github.com/tensorflow/tensorflow/issues/342
      """
      def __init__(self, dataset=None, num_features=0):
        if dataset is not None and type(dataset) != TfDataSet:
          self.labels = dataset.labels
          self.features = dataset.features
          self.num_features = dataset.num_features
          self.num_classes = dataset.num_classes
        else:
          self.features = SparseFeatures()
          self.num_features = num_features
          self.num_classes = None

        self.index_only = False
        self.total_features = self.num_features

        if type(dataset) != TfDataSet:
          # non-record input: feed the sparse batch through placeholders
          self.sp_indices = tf.placeholder(tf.int64, name='sp_indices')
          self.sp_shape = tf.placeholder(tf.int64, name='sp_shape')
          self.sp_ids_val = tf.placeholder(tf.int64, name='sp_ids_val')
          self.sp_weights_val = tf.placeholder(tf.float32, name='sp_weights_val')
          self.sp_ids = tf.SparseTensor(self.sp_indices, self.sp_ids_val, self.sp_shape)
          self.sp_weights = tf.SparseTensor(self.sp_indices, self.sp_weights_val, self.sp_shape)

          self.X = (self.sp_ids, self.sp_weights)
          self.Y = tf.placeholder(tf.int32)  # same as batch size
        else:
          # TFRecord input: the tensors come straight from the read graph
          self.X = (dataset.index, dataset.value)
          self.Y = dataset.label

        self.type = 'sparse'
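    The "sparse embedding lookup trick" from the issue linked above turns the (sp_ids, sp_weights) pair into a dense activation without materializing the sparse input. A minimal sketch of how self.X is typically consumed (the layer sizes here are hypothetical):

        # sketch: weighted sum of embedding rows selected by the sparse ids;
        # equivalent to a dense matmul against a mostly-zero input vector
        num_features = 1000000   # hypothetical vocabulary size
        hidden_size = 128        # hypothetical layer width

        weights = tf.Variable(tf.random_uniform([num_features, hidden_size], -0.05, 0.05))
        biases = tf.Variable(tf.zeros([hidden_size]))

        activation = tf.nn.embedding_lookup_sparse(
            weights, sp_ids, sp_weights, combiner='sum') + biases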

       

       

       

    MIN_AFTER_DEQUEUE = 10000

    def read(filename_queue):
      reader = tf.TFRecordReader()
      _, serialized_example = reader.read(filename_queue)
      return serialized_example

    def decode(batch_serialized_examples):
      features = tf.parse_example(
          batch_serialized_examples,
          features={
              'label' : tf.FixedLenFeature([], tf.int64),
              'index' : tf.VarLenFeature(tf.int64),
              'value' : tf.VarLenFeature(tf.float32),
          })

      label = features['label']
      index = features['index']
      value = features['value']

      return label, index, value

    def batch_inputs(files, batch_size, num_epochs=None, num_preprocess_threads=12):
      if not num_epochs: num_epochs = None

      with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer(
            files, num_epochs=num_epochs)

        serialized_example = read(filename_queue)
        batch_serialized_examples = tf.train.shuffle_batch(
            [serialized_example],
            batch_size=batch_size,
            num_threads=num_preprocess_threads,
            capacity=MIN_AFTER_DEQUEUE + (num_preprocess_threads + 1) * batch_size,
            # Ensures a minimum amount of shuffling of examples.
            min_after_dequeue=MIN_AFTER_DEQUEUE)

      return decode(batch_serialized_examples)

    class TfDataSet(object):
      def __init__(self, data_files):
        self.data_files = data_files
        #@TODO now only deal sparse input
        self.features = SparseFeatures()
        self.label = None

      def build_read_graph(self, batch_size):
        tf_record_pattern = self.data_files
        data_files = tf.gfile.Glob(tf_record_pattern)
        self.label, self.index, self.value = batch_inputs(data_files, batch_size)

      def next_batch(self, sess):
        label, index, value = sess.run([self.label, self.index, self.value])

        trX = (index, value)
        trY = label

        return trX, trY

       

       

       

    trainset = melt.load_dataset(trainset_file, is_record=FLAGS.is_record)
    if FLAGS.is_record:
      trainset.build_read_graph(batch_size)

    step = 0
    while not coord.should_stop():
      # with TFRecord input there is no feed_dict: the batch tensors are
      # already wired into the graph, so next_batch is no longer needed
      #self.trainer.X, self.trainer.Y = trainset.next_batch(self.session)
      _, cost_, accuracy_ = self.session.run([self.train_op, self.cost, self.accuracy])
      if step % 100 == 0:
        print 'step:', step, 'train precision@1:', accuracy_, 'cost:', cost_
      if step % 1000 == 0:
        pass
      step += 1
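    This fragment assumes the same queue-runner scaffolding as read-records.py above; for completeness, a sketch of the surrounding setup (self.session and step follow the fragment, the rest mirrors read_records):

        # sketch: the scaffolding the training loop above runs inside
        self.session = tf.InteractiveSession()
        self.session.run(tf.group(tf.initialize_all_variables(),
                                  tf.initialize_local_variables()))

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=self.session, coord=coord)
        try:
            pass  # the while loop above goes here
        except tf.errors.OutOfRangeError:
            print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
        finally:
            coord.request_stop()
        coord.join(threads)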

  • Original post: https://www.cnblogs.com/rocketfan/p/5765979.html