zoukankan      html  css  js  c++  java
  • 深度学习进度05(TFRecords存储和读取、神经网络)

    TFRecords存储和读取:

    什么是TFRecords:

     Example结构解析:

     

     

    写:

     def write_to_tfrecords(self, image_batch, label_batch):
            """
            将样本的特征值和目标值一起写入tfrecords文件
            :param image:
            :param label:
            :return:
            """
            with tf.compat.v1.python_io.TFRecordWriter("cifar10.tfrecords") as writer:
                # 循环构造example对象,并序列化写入文件
                for i in range(100):
                    image = image_batch[i].tostring()
                    label = label_batch[i][0]
                    # print("tfrecords_image:
    ", image)
                    # print("tfrecords_label:
    ", label)
                    example = tf.train.Example(features=tf.train.Features(feature={
                        "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                    }))
                    # example.SerializeToString()
                    # 将序列化后的example写入文件
                    writer.write(example.SerializeToString())
    
            return None

    读:

     def read_tfrecords(self):
            """
            读取TFRecords文件
            :return:
            """
            # 1、构造文件名队列
            file_queue = tf.compat.v1.train.string_input_producer(["cifar10.tfrecords"])
    
            # 2、读取与解码
            # 读取
            reader = tf.compat.v1.TFRecordReader()
            key, value = reader.read(file_queue)
    
            # 解析example
            feature = tf.compat.v1.parse_single_example(value, features={
                "image": tf.compat.v1.FixedLenFeature([], tf.string),
                "label": tf.compat.v1.FixedLenFeature([], tf.int64)
            })
            image = feature["image"]
            label = feature["label"]
            print("read_tf_image:
    ", image)
            print("read_tf_label:
    ", label)
    
            # 解码
            image_decoded = tf.compat.v1.decode_raw(image, tf.uint8)
            print("image_decoded:
    ", image_decoded)
            # 图像形状调整
            image_reshaped = tf.reshape(image_decoded, [self.height, self.width, self.channel])
            print("image_reshaped:
    ", image_reshaped)
    
            # 3、构造批处理队列
            image_batch, label_batch = tf.compat.v1.train.batch([image_reshaped, label], batch_size=100, num_threads=2, capacity=100)
            print("image_batch:
    ", image_batch)
            print("label_batch:
    ", label_batch)
    
            # 开启会话
            with tf.compat.v1.Session() as sess:
    
                # 开启线程
                coord = tf.train.Coordinator()
                threads = tf.compat.v1.train.start_queue_runners(sess=sess, coord=coord)
    
                image_value, label_value = sess.run([image_batch, label_batch])
                print("image_value:
    ", image_value)
                print("label_value:
    ", label_value)
    
                # 回收资源
                coord.request_stop()
                coord.join(threads)
    
            return None

     

     神经网络:

     感知机:

     主要用途:

     softmax回归:

     

     交叉熵损失:

     

     

     

     

  • 相关阅读:
    P1182 数列分段 Section II 题解
    P3853 路标设置题解
    二分模板
    P2678 跳石头题解
    P2440 木材加工题解
    P1024 一元三次方程求解题解
    快速下载vscode的方法
    P1824 进击的奶牛题解
    P1873 砍树题解
    用户登录之asp.net cookie的写入、读取与操作
  • 原文地址:https://www.cnblogs.com/dazhi151/p/14440901.html
Copyright © 2011-2022 走看看