zoukankan      html  css  js  c++  java
  • 【5】TensorFlow光速入门-图片分类完整代码

    本文地址:https://www.cnblogs.com/tujia/p/13862364.html

    系列文章:

    【0】TensorFlow光速入门-序

    【1】TensorFlow光速入门-tensorflow开发基本流程

    【2】TensorFlow光速入门-数据预处理(得到数据集)

    【3】TensorFlow光速入门-训练及评估

    【4】TensorFlow光速入门-保存模型及加载模型并使用

    【5】TensorFlow光速入门-图片分类完整代码

    【6】TensorFlow光速入门-python模型转换为tfjs模型并使用

    【7】TensorFlow光速入门-总结

    一、完整代码

    import pathlib
    import random
    import tensorflow as tf
    from tensorflow import keras
    import numpy as np
    import IPython.display as display
    import matplotlib.pyplot as plt
    
    # 读取文件夹图片数据
    data_path = '/tf/datasets/wnw'
    all_image_paths = []
    all_image_labels = []
    label_names = []
    data_root = pathlib.Path(data_path)
    i = 0
    for item in data_root.iterdir():
        label_names.append(item.name)
        for image in item.iterdir():
            all_image_paths.append(str(image))
            all_image_labels.append(i)
        i = i + 1
    print(label_names)
    print(len(all_image_paths))
    print(len(all_image_labels))
    
    # 抽样检查
    image_count = len(all_image_paths)
    for x in range(5):
        i = random.randint(0, image_count-1);
        image_path = all_image_paths[i]
        display.display(display.Image(image_path, width=100, height=100))
        print(label_names[all_image_labels[i]])
    
    # 图片 转 tensor3D 格式
    def load_and_preprocess_image(path):
        # 文件 转 tensor
        image = tf.io.read_file(path)
        # 普通 tensor 转 图片tensor,channels 为颜色通道,1表示灰图
        image = tf.image.decode_jpeg(image, channels=1)
        # 缩放图片尺寸为 100*100
        image = tf.image.resize(image, [100, 100])
        # 颜色的数值范围是0-255,所以 image/255,进一步将图片tensor数据数值范围缩到 0-1
        image /= 255
        return image
    
    # 批量处理图片
    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
    image_ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    
    # 抽样检查
    for i, image in enumerate(image_ds.take(5)):
        plt.imshow(image.numpy().squeeze(), cmap=plt.cm.gray_r)
        plt.grid(False)
        plt.xlabel(label_names[all_image_labels[i]])
        plt.show()
    
    # label 数据集
    label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(all_image_labels, tf.int64))
    
    # 打包图片及其label
    image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
    
    # 打乱数据
    image_count = len(all_image_paths)
    ds = image_label_ds.shuffle(buffer_size=image_count)
    ds = ds.repeat()
    ds = ds.batch(32)
    ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    print(ds)
    
    # 模型初始化(配置神经网络层)
    model = keras.Sequential([
        # 展平数据,输入类型要和数据集保持一致,我这里是100*100的灰图
        keras.layers.Flatten(input_shape=(100, 100, 1)),
        # 第二层是神经元
        keras.layers.Dense(128, activation='relu'),
        # 第三层的参数很重要,2表示分两类,如果要分5类就传5,10类就传10
        keras.layers.Dense(2, activation='softmax')
    ])
    
    # 优化器、损失函数及指标
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    
    # 训练 100 次
    model.fit(ds, epochs=100, steps_per_epoch=10)
    
    # 评估
    test_loss, test_acc = model.evaluate(ds, verbose=2, steps=10)
    
    # 预测
    predictions = model.predict(ds, steps=10)
    label = np.argmax(predictions[0])
    print(label_names[label])
    
    # 保存模型
    model.save('/tf/saved_model/wnw')

    二、jupyter 笔记本

    附件下载: wnw.ipynb

    解压缩后,上传 wnw.ipynb 到 tensorflow-tutorials 目录就行了

    参考【2】TensorFlow光速入门-数据预处理(得到数据集) 准备好图片数据后,直接运行 wnw.ipynb 就行了

    注:图片数据需为jpg格式,不能用png或gif格式的,否则会报错~~

    下一节,我们来看一下训练好的模型如果在 web 项目中应用:

    【6】TensorFlow光速入门-python模型转换为tfjs模型并使用

    本文链接:https://www.cnblogs.com/tujia/p/13862364.html


     完。

  • 相关阅读:
    Data Wrangling文摘:Non-tidy-data
    Data Wrangling文摘:Tideness
    Python文摘:Mixin 2
    Python文摘:Mixin
    Python文摘:Python with Context Managers
    Python学习笔记11
    SQL学习笔记9
    SQL学习笔记8
    SQL学习笔记7
    Python学习笔记10:内建结构
  • 原文地址:https://www.cnblogs.com/tujia/p/13862364.html
Copyright © 2011-2022 走看看