zoukankan      html  css  js  c++  java
  • 深度学习-Tensorflow2.2-卷积神经网络{3}-卫星图像识别卷积综合实例(二分类)-13

    import tensorflow as tf
    import matplotlib.pyplot as plt
    %matplotlib inline
    import numpy as np
    import pathlib
    

    数据读取及预处理

    data_dir = "./2_class"# 文件路径
    
    data_root = pathlib.Path(data_dir)# 构建路径对象
    
    for item in data_root.iterdir(): # 对目录进行迭代查看文件路径及对象
        print(item)
    

    在这里插入图片描述

    all_image_path = list(data_root.glob("*/*"))#使用glob方法及正则表达式提取目录里面所有文件
    len(all_image_path) # 1400个数据
    

    在这里插入图片描述

    all_image_path[:3]# 通过切片查看前3个文件
    

    在这里插入图片描述

    all_image_path = [str(path) for path in all_image_path]# 使用str把路径变成一个实际的路径
    all_image_path[10:12]
    

    在这里插入图片描述

    import random
    random.shuffle(all_image_path)# 把内容乱序
    all_image_path[10:12]
    

    在这里插入图片描述

    image_count = len(all_image_path)
    image_count # 记录图片的张数
    

    在这里插入图片描述

    label_names = sorted (item.name for item in data_root.glob("*/")) # 提取分类名字
    label_names
    

    在这里插入图片描述

    # 编码airplane对应0, lake对应1
    label_to_index = dict((name,index) for index,name in enumerate(label_names))
    label_to_index
    

    在这里插入图片描述

    all_image_path[:3]
    

    在这里插入图片描述

    pathlib.Path("2_class\lake\lake_405.jpg").parent.name
    

    在这里插入图片描述

    all_image_label = [label_to_index[pathlib.Path(p).parent.name]for p in all_image_path]
    all_image_label[:5]
    all_image_path[:5]
    

    在这里插入图片描述

    import IPython.display as display
    
    index_to_label = dict((v,k) for k,v in label_to_index.items())
    index_to_label
    

    在这里插入图片描述

    读取和解码图片

    for n in range(3):
        image_index = random.choice(range(len(all_image_path)))
        display.display(display.Image(all_image_path[image_index]))
        print(index_to_label[all_image_label[image_index]])
        print()
    

    在这里插入图片描述

    # 对单张图片进行处理过程
    # 使用tf读取图片
    img_path = all_image_path[0]
    img_path
    

    在这里插入图片描述

    img_raw = tf.io.read_file(img_path)
    img_raw # 二进制的图片
    

    在这里插入图片描述

    # 解码图片
    img_tensor = tf.image.decode_image(img_raw)
    img_tensor.shape
    

    在这里插入图片描述

    img_tensor
    

    在这里插入图片描述

    img_tensor = tf.cast(img_tensor,tf.float32)# 转换数据类型为float32
    img_tensor
    

    在这里插入图片描述

    # 标准化
    img_tensor = img_tensor/255
    

    在这里插入图片描述

    定义函数对图片进行处理

    # 定义函数对图片进行处理
    def load_preprosess_image(img_paht):
        img_raw = tf.io.read_file(img_path) # 读取图片的路径
        img_tensor = tf.image.decode_jpeg(img_raw,channels=3)# 解码图片channels=3代表彩色图片
        img_tensor = tf.image.resize(img_tensor,[256,256]) #定义图片大小
        img_tensor = tf.cast(img_tensor,tf.float32) # 转化图片类型
        img = img_tensor/255 # 标准化
        return img
    

    在这里插入图片描述

    使用tf.data 构建图片输入管道

    # 构造tf.data对所有图片进行处理
    path_ds = tf.data.Dataset.from_tensor_slices(all_image_path)
    image_dataset = path_ds.map(load_preprosess_image)# 使用上面定义好的图片处理函数处理all_image_path中所有的图片
    label_dataset = tf.data.Dataset.from_tensor_slices(all_image_label)
    

    在这里插入图片描述
    在这里插入图片描述

    # 合并
    dataset = tf.data.Dataset.zip((image_dataset,label_dataset))
    

    在这里插入图片描述

    # 划分测试集与训练集
    test_count = int(image_count*0.2)
    train_count = image_count-test_count
    

    在这里插入图片描述

    train_dataset = dataset.skip(test_count) # skip 跳过测试集的张数
    test_dataset = dataset.take(test_count)
    BATCH_SIZE = 32# 每次训练32张
    
    train_dataset = train_dataset.shuffle(buffer_size=train_count).batch(BATCH_SIZE)
    

    在这里插入图片描述

    test_dataset = test_dataset.batch(BATCH_SIZE)
    

    建立模型

    # 增加BN层
    #建立模型
    model = tf.keras.Sequential() # 顺序模型
    model.add(tf.keras.layers.Conv2D(64,(3,3),input_shape=(256,256,3)))# 添加一个卷积层
    model.add(tf.keras.layers.BatchNormalization()) # 批标准化
    model.add(tf.keras.layers.Activation("relu")) # 激活层
    
    model.add(tf.keras.layers.Conv2D(64,(3,3)))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Activation("relu"))
    
    model.add(tf.keras.layers.MaxPooling2D())
    model.add(tf.keras.layers.Conv2D(128,(3,3)))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Activation("relu"))
    
    model.add(tf.keras.layers.MaxPooling2D())
    model.add(tf.keras.layers.Conv2D(256,(3,3)))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Activation("relu"))
    
    model.add(tf.keras.layers.MaxPooling2D())
    model.add(tf.keras.layers.Conv2D(512,(3,3)))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Activation("relu"))
    
    model.add(tf.keras.layers.MaxPooling2D())
    model.add(tf.keras.layers.Conv2D(1024,(3,3)))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Activation("relu"))
    
    model.add(tf.keras.layers.GlobalAveragePooling2D()) # 全局池化
    model.add(tf.keras.layers.Dense(1024))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Activation("relu"))
    
    model.add(tf.keras.layers.Dense(256))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Activation("relu"))
    
    model.add(tf.keras.layers.Dense(1,activation="sigmoid"))#二分类使用sigmoid激活
    
    model.summary()
    

    在这里插入图片描述

    # 编译模型
    model.compile(optimizer="adam",
                 loss="binary_crossentropy",
                 metrics=["acc"])
    
    steps_per_epoch = train_count//BATCH_SIZE
    validation_steps = test_count//BATCH_SIZE # 步数
    
    
    # 训练模型
    history = model.fit(train_dataset,epochs=30,
                        steps_per_epoch=steps_per_epoch,
                        validation_data=test_dataset,
                        validation_steps=validation_steps)
    

    在这里插入图片描述

  • 相关阅读:
    oracle 数据库服务名怎么查
    vmware vsphere 6.5
    vSphere虚拟化之ESXi的安装及部署
    ArcMap中无法添加ArcGIS Online底图的诊断方法
    ArcGIS中字段计算器(高级计算VBScript、Python)
    Bad habits : Putting NOLOCK everywhere
    Understanding the Impact of NOLOCK and WITH NOLOCK Table Hints in SQL Server
    with(nolock) or (nolock)
    What is “with (nolock)” in SQL Server?
    Changing SQL Server Collation After Installation
  • 原文地址:https://www.cnblogs.com/gemoumou/p/14186269.html
Copyright © 2011-2022 走看看