zoukankan      html  css  js  c++  java
  • 五分类-迁移学习

    由于移动云的图像分类效果不太理想,我将采用一种巧妙的方法——迁移学习来实现。即在预训练模型的基础上,采用101层的深度残差网络ResNet-101。

    迁移学习

    (1) 迁移学习简介

    什么是迁移学习呢?百度词条给出了一个简明的定义:迁移学习是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中。以我们的图像分类任务为例:

    假如任务A的任务是猫狗分类,任务B是要对老虎、狮子进行分类。可以发现,任务 A 和任务 B 存在大量的共享知识,比如这些动物都可以从毛发,体型,形态等方面进行辨别。因此在已经存在一个针对任务A训练好的模型前提下,在训练任务B的模型时,我们可以不从零开始训练,而是基于在任务 A 上获得的知识再进行训练。在这里,针对A任务已经训练好的模型参数称之为:预训练模型。

    我的训练模型代码

    from tensorflow.keras.preprocessing.image import ImageDataGenerator
    import matplotlib.pyplot as plt
    from model import resnet101
    import tensorflow as tf
    import json
    import os
    import PIL.Image as im
    import numpy as np
    image_path =  "./Datasets/CIFAR100/"  # 数据集的路径
    train_dir = image_path + "train"
    validation_dir = image_path + "test"
    im_height = 224
    im_width = 224
    batch_size =64
    epochs = 5
    
    _R_MEAN = 123.68
    _G_MEAN = 116.78
    _B_MEAN = 103.94
    
    def pre_function(img):  # 图像预处理
        img = img - [_R_MEAN, _G_MEAN, _B_MEAN]
        return img
    train_image_generator = ImageDataGenerator(horizontal_flip=True,
                                               preprocessing_function=pre_function)
    train_data_gen = train_image_generator.flow_from_directory(directory=train_dir,
                                                               batch_size=batch_size,
                                                               shuffle=True,
                                                               target_size=(im_height, im_width),
                                                               class_mode='binary')
    total_train = train_data_gen.n  # 训练集样本总数
    validation_image_generator = ImageDataGenerator(preprocessing_function=pre_function)
    val_data_gen = validation_image_generator.flow_from_directory(directory=validation_dir,
                                                                  batch_size=batch_size,
                                                                  shuffle=False,
                                                                  target_size=(im_height, im_width),
                                                                  class_mode='binary')
    # img, _ = next(train_data_gen)
    total_val = val_data_gen.n  # 验证集样本总数
    class_indices = train_data_gen.class_indices
    # 转换类别字典中键和值的位置
    inverse_dict = dict((val, key) for key, val in class_indices.items())
    # 将数字标签字典写入json文件:class_indices.json
    json_str = json.dumps(inverse_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
    
    feature = resnet101(num_classes=3, include_top=False)
    feature.load_weights('./model-resnet/pretrain_weights.ckpt')  # 加载预训练模型
    feature.trainable = False  # 训练时冻结与训练模型参数
    feature.summary()  # 打印预训练模型参数
    model = tf.keras.Sequential([feature,
                                 tf.keras.layers.GlobalAvgPool2D(),
                                 tf.keras.layers.Dropout(rate=0.5),
                                 tf.keras.layers.Dense(1024),
                                 tf.keras.layers.Dropout(rate=0.5),
                                 tf.keras.layers.Dense(105),
                                 tf.keras.layers.Softmax()])
    # model.build((None, 224, 224, 3))
    model.summary()  # 打印增加层的参数
    model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.001),
                  loss='sparse_categorical_crossentropy',
                  metrics=['acc'])
                  
    history = model.fit(
        train_data_gen,
        steps_per_epoch=100,
        epochs=100,
        validation_data=val_data_gen,
        validation_steps=100)
    

    运行结果

  • 相关阅读:
    Linux安装Gradle
    MySQL 使用自增ID主键和UUID 作为主键的优劣比较详细过程(从百万到千万表记录测试)
    Websocket实现即时通讯
    Java线程池的使用
    Html5视频播放器-VideoJS+Audio标签实现视频,音频及字幕同步播放
    几种常用的认证机制
    Spring 接口参数加密传输
    Java 三种方式实现接口校验
    Spring AOP实现 Bean字段合法性校验
    RabbitMQ进程结构分析与性能调优
  • 原文地址:https://www.cnblogs.com/chenaiiu/p/14815323.html
Copyright © 2011-2022 走看看