zoukankan      html  css  js  c++  java
  • 深度学习趣谈:什么是迁移学习?(附带Tensorflow代码实现)

    一.迁移学习的概念

    什么是迁移学习呢?迁移学习可以由下面的这张图来表示:

     这张图最左边表示了迁移学习也就是把已经训练好的模型和权重直接纳入到新的数据集当中进行训练,但是我们只改变之前模型的分类器(全连接层和softmax/sigmoid),这样就可以节省训练的时间的到一个新训练的模型了!

    但是为什么可以这么做呢?

    二.为什么可以使用迁移学习?

    一般在图像分类的问题当中,卷积神经网络最前面的层用于识别图像最基本的特征,比如物体的轮廓,颜色,纹理等等,而后面的层才是提取图像抽象特征的关键,因此最好的办法是我们只需要保留卷积神经网络当中底层的权重,对顶层和新的分类器进行训练即可。那么在图像分类问题当中,我们如何使用迁移学习呢?一般使用迁移学习,也就是预训练神经网络的步骤如下;

    1.冻结预训练网络的卷积层权重

    2.置换旧的全连接层,换上新的全连接层和分类器

    3.解冻部分顶部的卷积层,保留底部卷积神经网络的权重

    4.同时对卷积层和全连接层的顶层进行联合训练,得到新的网络权重

    既然我们知道了迁移学习的基本特点,何不试试看呢?

    三.迁移学习的代码实现

    我们使用迁移学习的方法来进行猫狗图像的分类识别,猫猫的图像在我的文件夹里如下图所示:

    然后导包:

    import tensorflow as tf
    from tensorflow import keras
    import matplotlib.pyplot as plt
    import numpy as np
    import glob
    import os

    获取图片的路径,标签,制作batch数据,图片的路径我存放在了F盘下的train文件夹下,路径为:F://UNIVERSITY STUDY/AI/dataset/catdog/train/。

    代码如下:

    keras=tf.keras
    layers=tf.keras.layers
    #得到图片的所有label
    train_image_label=[int(p.split("\")[1]=='cat') for p in train_image_path ]
    
    
    #现在我们的jpg文件进行解码,变成三维矩阵
    def load_preprosess_image(path,label):
        #读取路径
        image=tf.io.read_file(path)
        #解码
        image=tf.image.decode_jpeg(image,channels=3)#彩色图像为3个channel
        #将图像改变为同样的大小,利用裁剪或者扭曲,这里应用了扭曲
        image=tf.image.resize(image,[360,360])
        #随机裁剪图像
        image=tf.image.random_crop(image,[256,256,3])
        #随机上下翻转图像
        image=tf.image.random_flip_left_right(image)
        #随机上下翻转
        image=tf.image.random_flip_up_down(image)
        #随机改变图像的亮度
        image=tf.image.random_brightness(image,0.5)
        #随机改变对比度
        image=tf.image.random_contrast(image,0,1)
        #改变数据类型
        image=tf.cast(image,tf.float32)
        #将图像进行归一化
        image=image/255
        #现在还需要对label进行处理,我们现在是列表[1,2,3],
        #需要变成[[1].[2].[3]]
        label=tf.reshape(label,[1])
        return image,label
    
    train_image_ds=tf.data.Dataset.from_tensor_slices((train_image_path,train_image_label))
    AUTOTUNE=tf.data.experimental.AUTOTUNE#根据计算机性能进行运算速度的调整
    train_image_ds=train_image_ds.map(load_preprosess_image,num_parallel_calls=AUTOTUNE)
    #现在train_image_ds就读取进来了,现在进行乱序和batchsize的规定
    BATCH_SIZE=32
    train_count=len(train_image_path)
    #现在设置batch和乱序
    train_image_ds=train_image_ds.shuffle(train_count).batch(BATCH_SIZE)
    train_image_ds=train_image_ds.prefetch(AUTOTUNE)#预处理一部分处理,准备读取
    
    imags,labels=iter(train_image_ds).next()#放到生成器里,单独取出数据
    plt.imshow(imags[30])

    显示出制作batch数据当中的猫猫图片:

     搭建网络架构,引入经典图像分类模型VGG16,同时调用VGG16预训练网络的权重。最后调整卷积层的最后三层为可训练的,也就是说顶层的卷积神经网路可以和全连接层分类器一起进行联合训练:

    conv_base=keras.applications.VGG16(weights='imagenet',include_top=False)
    #weights设置为imagenet表示使用imagebnet训练出来的权重,如果填写False表示不使用权重
    #仅适用网络架构,include_top表示是否使用用于分类的全连接层
    #我们在这个卷积层上添加全连接层和输出层即可
    model=keras.Sequential()
    model.add(conv_base)
    model.add(layers.GlobalAveragePooling2D())
    model.add(layers.Dense(512,activation='relu'))
    model.add(layers.Dense(1,activation='sigmoid'))
    
    conv_base.trainable=True#一共有19层
    for layer in conv_base.layers[:-3]:
        layer.trainable=False
    #从第一层到倒数第三层重新设置为是不可训练的,现在卷积的顶层已经解冻,开始联合训练
    
    #编译这个网络
    model.compile(optimizer=keras.optimizers.Adam(lr=0.001),
                  loss='binary_crossentropy',
                  metrics=['acc'])
    
    history=model.fit(
    train_image_ds,
    steps_per_epoch=train_count//BATCH_SIZE,
        epochs=1
    )

    仅仅训练一个epoch的结果如下所示;

    Train for 62 steps
    62/62 [==============================] - 469s 8s/step - loss: 0.6323 - acc: 0.6159

    一次迭代准确率已经达到了百分之六十。怎么样呢?你现在对迁移学习有一定的感觉了吗?

  • 相关阅读:
    Cocos2d-x游戏《雷电大战》开源啦!要源代码要资源快快来~~
    Tomcat部署项目时出错java.lang.IllegalStateException: ContainerBase.addChild: start:org.apache.catalina.Life
    PCA主成分分析Python实现
    C语言知识结构之二
    javascript中构造函数的返回值问题和new对象的过程
    poj 1694 An Old Stone Game 树形dp
    Android新技术学习——阿里巴巴免Root无侵入AOP框架Dexposed
    c++中vector向量几种情况的总结(向量指针,指针的向量)
    Hash分析
    三期_day05_Dao层的准备工作_II
  • 原文地址:https://www.cnblogs.com/geeksongs/p/13330728.html
Copyright © 2011-2022 走看看