zoukankan      html  css  js  c++  java
  • 【tensorflow】神经网络:自制数据集

    在实际应用中,我们常常需要自制数据集,解决本领域应用,而数据通常是图片或文字,需要做格式转换,才能在训练时使用。

    代码:

    import tensorflow as tf
    from PIL import Image
    import numpy as np
    import os
    
    # 训练用的输入特征和标签
    x_train_readpath = "class4/MNIST_FC/mnist_image_label/mnist_train_jpg_60000/"
    y_train_readpath = "class4/MNIST_FC/mnist_image_label/mnist_train_jpg_60000.txt"
    x_train_savapath = "class4/MNIST_FC/mnist_image_label/mnist_x_train.npy"
    y_train_savapath = "class4/MNIST_FC/mnist_image_label/mnist_y_train.npy"
    
    # 测试用的输入特征和标签
    x_test_readpath = "class4/MNIST_FC/mnist_image_label/mnist_test_jpg_10000/"
    y_test_readpath = "class4/MNIST_FC/mnist_image_label/mnist_test_jpg_10000.txt"
    x_test_savapath = "class4/MNIST_FC/mnist_image_label/mnist_x_test.npy"
    y_test_savapath = "class4/MNIST_FC/mnist_image_label/mnist_y_test.npy"
    
    # 读取输入特征和标签
    def generateData(x_path, y_path):
        f = open(y_path, "r")     # 以只读形式打开存放标签的文件
        contents = f.readlines()  # 按行读取文件中的所有数据
        f.close()                 # 关闭文件
    
        # 建立空列表,存放读出来的数据
        x, y = [], []
        for content in contents:
            # 数据存放形式为:文件名 标签
            # 以空格分开后,value[0]=文件名,value[1]=标签
            value = content.split()
    
            img_path = x_path + value[0]      # 拼接出训练图片完整路径
            img = Image.open(img_path)        # 读取图片
            img = np.array(img.convert("L"))  # 将图片变为 8位宽 灰度值的np.array格式
            img = img/255.0                   # 数据归一化
    
            x.append(img)                     # 保存读取出来的输入特征和标签
            y.append(value[1])
    
            print("loding:" + content)        # 打印状态提示
    
        x = np.array(x)         # [] -> np.array
        y = np.array(y)
        y = y.astype(np.int64)  # 将y中的数据统一设置为int64类型
    
        return x, y
    
    if os.path.exists(x_train_savapath) and os.path.exists(y_train_savapath) and os.path.exists(x_test_savapath) and os.path.exists(y_test_savapath):
        # 数据文件已存在,直接读取
        x_train_save = np.load(x_train_savapath)
        x_train = np.reshape(x_train_save, (len(x_train_save), 28, 28))
        y_train = np.load(y_train_savapath)
    
        x_test_save = np.load(x_test_savapath)
        x_test = np.reshape(x_test_save, (len(x_test_save), 28, 28))
        y_test = np.load(y_test_savapath)
    else:
        # 数据文件不存在,生成数据文件
        x_train, y_train = generateData(x_train_readpath, y_train_readpath)
        x_test, y_test = generateData(x_test_readpath, y_test_readpath)
    
        x_train_save = np.reshape(x_train, (len(x_train), -1))
        x_test_save = np.reshape(x_test, (len(x_test), -1))
        np.save(x_train_savapath, x_train_save)
        np.save(y_train_savapath, y_train)
        np.save(x_test_savapath, x_test_save)
        np.save(y_test_savapath, y_test)
    
    # 声明神经网络结构
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(10, activation="softmax")
    ])
    
    # 配置训练方法(优化器,损失函数,评测方法)
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=[tf.keras.metrics.sparse_categorical_accuracy])
    
    # 执行训练过程
    model.fit(x_train, y_train,
              batch_size=32, epochs=5,
              validation_data=(x_test, y_test),
              validation_freq=1)
    
    # 打印网络结构和参数
    model.summary()
  • 相关阅读:
    JAVA中对null进行强制类型转换
    git 初次push
    svn还原与本地版本回退
    后台用map接收数据,报类型转换错误
    eclipse从svn导入静态文件
    APP项目下载及运行
    Yii2中如何使用CodeCeption
    开发资源整合
    工作流设计参考(包括PHP实现)
    PHP单元测试利器:PHPUNIT初探
  • 原文地址:https://www.cnblogs.com/bjxqmy/p/13535588.html
Copyright © 2011-2022 走看看