zoukankan      html  css  js  c++  java
  • 迁移学习 colab 完整示例:fruits-360 数据集

    这里当前目录下已经有fruits-360这个数据集. 关于调用数据集的方法可以查看我另一篇文章.

    准备

    import tensorflow as tf
    import tensorflow.keras as keras
    from tensorflow.keras.preprocessing.image import load_img, img_to_array, array_to_img, ImageDataGenerator
    

    创建 Generator

    创建 ImageDataGenerator. 由于这个数据集足够大, 所以不需要进行 image augmentation.

    train_datagen = ImageDataGenerator(rescale=1./255)
    test_datagen = ImageDataGenerator(rescale=1./255)
    train_generator = train_datagen.flow_from_directory(
            "fruits-360/Training",
            target_size=(100, 100),
            batch_size=32,
            class_mode='categorical')
    
    validation_generator = test_datagen.flow_from_directory(
        "fruits-360/Test",
        target_size=(100, 100),
        batch_size=32,
        class_mode='categorical')
    

    运行后看到如下输出表示创建成功.

    Found 67692 images belonging to 131 classes.
    Found 22688 images belonging to 131 classes.

    模型

    这里使用的是 Xception 模型.

    from tensorflow.keras.applications.xception import preprocess_input
    from tensorflow.keras.applications.xception import decode_predictions
    from tensorflow.keras.applications.xception import Xception
    
    tf.keras.backend.clear_session()
    base_model = tf.keras.applications.Xception(
        weights='imagenet',  # Load weights pre-trained on ImageNet.
        input_shape=(100, 100, 3),
        include_top=False)  # Do not include the ImageNet classifier at the top.
    input_layer = tf.keras.Input(shape=(100, 100, 3))
    base_model.trainable = False
    # x = data_augmentation(input_layer)
    x = base_model(input_layer, training = False)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dense(64, activation = 'relu')(x)
    x = tf.keras.layers.Dropout(0.2)(x)  # Regularize with dropout
    output_layer = tf.keras.layers.Dense(131, activation = 'softmax')(x)
    model = tf.keras.Model(input_layer, output_layer)
    model.summary()
    

    base_model.trainable = False将会冻结 Xception 模型的权重, 在训练中不会被更新. 即使用已经训练好的权重.

    x = base_model(input_layer, training = False)training=False可以确保 base模型处于 Inference phase, 而不是Training phase.

    opt = tf.keras.optimizers.Adam(learning_rate=0.01)
    model.compile(loss='categorical_crossentropy', optimizer=opt, metrics = ['accuracy'])
    
    model.fit(train_generator, epochs=5, steps_per_epoch = 67692//32, validation_data=validation_generator)
    
  • 相关阅读:
    超经典~超全的jQuery插件大全
    如何用PHP做到页面注册审核
    php实现签到功能
    php中的实用分页类
    微信小程序,超能装的实例教程
    php之 常用的 流程管理
    php之 人员的权限管理(RBAC)
    php之简单的文件管理(基本功能)
    php最新学习-----文件的操作
    关于LAMP的配置之(虚拟机的安装、创建、配置)
  • 原文地址:https://www.cnblogs.com/yaos/p/14014112.html
Copyright © 2011-2022 走看看