zoukankan      html  css  js  c++  java
  • 【TensorFlow】使用TensorFlow Lite Model Maker训练模型

    基础 LeNet-5 风格的简单 CNN 模型

    import tensorflow as tf
    from keras.preprocessing.image import ImageDataGenerator
    from tensorflow.keras.optimizers import RMSprop
    
    # Dataset directories; flow_from_directory reads one sub-folder per class.
    train_dir = "E://PycharmProjects//LearingImageData//DataSets//train"
    validation_dir = "E://PycharmProjects//LearingImageData//DataSets//var"

    # Training images: rescale pixels to [0, 1] and apply random augmentation.
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=40,  # random rotation in [-40 deg, 40 deg]
        width_shift_range=0.2,  # random horizontal shift (fraction of width)
        height_shift_range=0.2,  # random vertical shift (fraction of height)
        shear_range=0.2,  # random shear
        zoom_range=0.2,  # random zoom
        horizontal_flip=True,  # random horizontal flip
        fill_mode='nearest')  # how pixels exposed by the transforms are filled
    train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=(300, 300),
        batch_size=128,  # 128 images per batch
        # 'binary' only suits two classes; with more classes use
        # 'categorical' (one-hot labels).
        class_mode='categorical',
    )

    # Validation images: rescale only, no augmentation.
    test_datagen = ImageDataGenerator(rescale=1./255)
    validation_generator = test_datagen.flow_from_directory(
        validation_dir,
        target_size=(300, 300),
        batch_size=128,  # 128 images per batch
        # Must use the same one-hot label format as the training generator,
        # otherwise the labels do not match the model's output shape.
        class_mode='categorical',
    )

    # Take the class count from the training folder layout instead of
    # hard-coding it, so the output layer always matches the labels.
    num_classes = train_generator.num_classes

    # A simple CNN: three Conv2D + MaxPool stages, then a dense classifier.
    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(16, (3, 3), activation='relu',
                               input_shape=(300, 300, 3)),
        tf.keras.layers.MaxPool2D(2, 2),
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),  # every conv layer is followed by max pooling
        tf.keras.layers.MaxPool2D(2, 2),
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
        tf.keras.layers.MaxPool2D(2, 2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512, activation='relu'),
        # One softmax unit per class. A single sigmoid unit cannot represent
        # one-hot labels for 3 classes (0, 1, 2).
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ])

    # categorical_crossentropy pairs with softmax outputs and one-hot labels.
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    model.summary()
    history = model.fit(
        train_generator,
        epochs=20,
        validation_data=validation_generator,
        # NOTE(review): only 2 training batches / 1 validation batch per
        # epoch; len(train_generator) would cover the whole set - confirm intent.
        steps_per_epoch=2,
        validation_steps=1,
        verbose=2)
    model.save("E://PycharmProjects//LearingImageData//DataSets//finalModel")

    基于 MobileNetV2 的改进模型(训练过程中自动保存训练准确率大于 0.8 且创新高的模型,并转换成 TFLite 格式)

    import tensorflow as tf
    import keras
    import keras_applications.mobilenet_v2 as mobilenet
    import numpy as np
    
    
    def get_model(num_classes=3, input_shape=(224, 224, 3)):
        """Build a MobileNetV2-based classifier with a frozen backbone.

        An ImageNet-pretrained MobileNetV2 feature extractor (without its top
        layers) is frozen. A global-average-pooling + softmax Dense head with
        L1 kernel regularization is added on top.

        Args:
            num_classes: Number of output classes (units in the softmax layer).
                Defaults to 3, matching the 3-class soft labels used in
                get_train_input.
            input_shape: Input image shape as (height, width, channels).

        Returns:
            An uncompiled keras Model named 'zc_classifier'.
        """
        # Pretrained backbone, built against the standalone keras modules.
        base_model = mobilenet.MobileNetV2(include_top=False,
                                           weights='imagenet',
                                           input_shape=input_shape,
                                           backend=keras.backend,
                                           layers=keras.layers,
                                           models=keras.models,
                                           utils=keras.utils)
        # Freeze every backbone layer so only the new head is trained.
        for layer in base_model.layers:
            layer.trainable = False
        # Global average pooling followed by the classification layer.
        pooled = keras.layers.GlobalAveragePooling2D()(base_model.output)
        # Use the regularizer from the same keras package as the layers,
        # rather than mixing in a tf.keras object.
        outputs = keras.layers.Dense(units=num_classes,
                                     activation='softmax',
                                     kernel_regularizer=keras.regularizers.l1(0.005))(pooled)
        model = keras.models.Model(base_model.input, outputs, name='zc_classifier')
        return model
    
    
    def get_train_input(directory="E://PycharmProjects//LearingImageData//DataSets//train",
                        batch_size=32):
        """Yield augmented, MobileNetV2-preprocessed training batches with soft labels.

        Images are read from ``directory`` with light augmentation, resized to
        224x224 and passed through ``mobilenet.preprocess_input``. The one-hot
        labels from ``flow_from_directory`` are then replaced by fixed soft-label
        distributions.

        Args:
            directory: Training image root. The previous literal had lost its
                path separators (and contained a TAB character), so it could
                not point at the training folder.
            batch_size: Number of images per yielded batch.

        Yields:
            ``(images, labels)`` tuples. This assumes the directory contains
            exactly 3 class sub-folders, so labels have 3 columns.
        """
        train_image = keras.preprocessing.image.ImageDataGenerator(
            featurewise_center=False, samplewise_center=False,
            featurewise_std_normalization=False,
            samplewise_std_normalization=False, zca_whitening=False,
            zca_epsilon=1e-06, rotation_range=0.1,
            width_shift_range=0.1, height_shift_range=0.1,
            shear_range=0.0,
            zoom_range=0.1, channel_shift_range=0.0,
            fill_mode='nearest', cval=0.0, horizontal_flip=True,
            vertical_flip=False, rescale=None,
            preprocessing_function=lambda x: mobilenet.preprocess_input(
                x, data_format="channels_last"),
            data_format="channels_last")

        train_gen = train_image.flow_from_directory(directory,
                                                    target_size=(224, 224),
                                                    batch_size=batch_size,
                                                    interpolation="bilinear")
        for img, lab in train_gen:
            # Recover the hard class index of each sample before overwriting.
            classes = np.argmax(lab, axis=-1)
            # Soft labels: class 0 keeps 0.85. Classes 1 and 2 keep 0.8 and
            # give 0.15 to each other, presumably because those two classes
            # are easily confused. TODO confirm.
            lab[classes == 0] = [0.85, 0.075, 0.075]
            lab[classes == 1] = [0.05, 0.8, 0.15]
            lab[classes == 2] = [0.05, 0.15, 0.8]
            yield img, lab
    
    
    def get_test_input(directory="E://PycharmProjects//LearingImageData//DataSets//var"):
        """Return an iterator over MobileNetV2-preprocessed evaluation images.

        No augmentation is applied. Images are resized to 224x224 and produced
        one per batch (batch_size=1) in a fixed order (shuffle=False), so every
        evaluation pass sees the same samples in the same order.

        Args:
            directory: Validation/test image root. The previous literal had
                lost its path separators, so it could not point at the folder.

        Returns:
            The iterator returned by ``flow_from_directory``.
        """
        test_image = keras.preprocessing.image.ImageDataGenerator(
            preprocessing_function=lambda x: mobilenet.preprocess_input(
                x, data_format="channels_last"),
            data_format="channels_last")
        test_gen = test_image.flow_from_directory(directory,
                                                  target_size=(224, 224),
                                                  batch_size=1,
                                                  shuffle=False,
                                                  interpolation="bilinear")
        return test_gen
    
    
    def custom_loss(y_true, y_pred):
        """Class-weighted binary cross-entropy, averaged over the batch.

        For each class i the per-sample term is
        ``y_true[:, i] * log(p_i) + (1 - y_true[:, i]) * log(1 - p_i)``.
        Each term is scaled by ``class_weight[i]``, the terms are summed over
        the 3 classes, and the batch mean is negated.

        Args:
            y_true: Target probabilities with (at least) 3 columns, e.g. the
                soft labels from get_train_input.
            y_pred: Predicted probabilities with 3 columns (the softmax head).

        Returns:
            A scalar loss tensor.
        """
        # Per-class weights: classes with more samples get lower weight.
        class_weight = [0.33, 0.33, 0.34]
        # Keep predictions strictly inside (0, 1). A saturated softmax would
        # otherwise give log(0) = -inf and NaN gradients.
        eps = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)
        y_true = tf.cast(y_true, y_pred.dtype)
        loss = 0.0
        for i in range(3):
            loss += (y_true[:, i] * tf.math.log(y_pred[:, i]) +
                     (1 - y_true[:, i]) * tf.math.log(1 - y_pred[:, i])) * class_weight[i]
        return -tf.reduce_mean(loss)
    
    
    def train(model: tf.keras.Model, train_gen, test_gen, save_function):
        """Compile and fit ``model``, then save it to ./zc.h5.

        Uses Adam with learning rate 0.01 and ``custom_loss``. Trains for
        8 epochs of 6 steps each, validates on ``test_gen`` and calls
        ``save_function`` as a Keras callback.
        """
        model.compile(optimizer=keras.optimizers.Adam(0.01),
                      loss=custom_loss,
                      metrics=['accuracy'])
        model.fit_generator(
            train_gen,
            steps_per_epoch=6,
            epochs=8,
            validation_data=test_gen,
            callbacks=[save_function],
        )
        model.save('./zc.h5')
    
    
    def test(model: tf.keras.Model, test_gen):
        """Evaluate ``model`` on ``test_gen`` and print what evaluate_generator returns."""
        results = model.evaluate_generator(test_gen, verbose=1)
        print(results)
    
    
    class Save(keras.callbacks.Callback):
        """Keras callback that checkpoints the model and exports it to TFLite.

        At the end of every epoch except the first, the ``monitor`` value from
        the epoch logs is compared with the threshold and the best value so
        far. If it exceeds both, the model is saved as
        ``kears_model_<epoch>_acc=<value>.h5`` and converted to
        ``converted_model_<value>.tflite`` in the current working directory.
        """

        def __init__(self, monitor="accuracy", threshold=0.8):
            """Create the callback.

            Args:
                monitor: Key in the epoch ``logs`` dict to track. Defaults to
                    the training accuracy. Pass "val_accuracy" to track
                    validation accuracy instead.
                threshold: Value the metric must exceed before anything is saved.
            """
            super(Save, self).__init__()
            self.monitor = monitor
            self.threshold = threshold
            self.max_acc = 0.0

        def on_epoch_end(self, epoch, logs=None):
            logs = logs or {}
            if self.monitor not in logs:
                raise KeyError("Save callback: metric '%s' not found in epoch logs %s"
                               % (self.monitor, sorted(logs)))
            # Attribute name kept for compatibility; it holds whichever metric
            # is monitored (training accuracy by default).
            self.val_acc = logs[self.monitor]
            # The first epoch (epoch 0) is never saved.
            if epoch == 0:
                return
            if self.val_acc > self.max_acc and self.val_acc > self.threshold:
                # Use self.model, which Keras sets when the callback is
                # attached to fit. A module-level global `model` only exists
                # when this file is run as a script.
                self.model.save("kears_model_" + str(epoch) + "_acc=" + str(self.val_acc) + ".h5")
                self.max_acc = self.val_acc
                converter = tf.lite.TFLiteConverter.from_keras_model(self.model)
                converter.experimental_new_converter = True
                tflite_model = converter.convert()
                # Context manager so the file handle is closed deterministically.
                with open("converted_model_" + str(self.val_acc) + ".tflite", "wb") as f:
                    f.write(tflite_model)
    
    
    if __name__ == "__main__":
        # Build the classifier and its input pipelines, then train with the
        # checkpoint / TFLite-export callback attached.
        model = get_model()
        train_gen, test_gen = get_train_input(), get_test_input()
        save_function = Save()
        train(model, train_gen, test_gen, save_function)
        # To evaluate instead of train, call: test(model, test_gen)

    使用TensorFlow Lite Model Maker训练模型

    # Install TensorFlow Lite Model Maker (a shell / notebook command, not Python).
    pip install -q tflite-model-maker
    import os
    
    import numpy as np
    
    import tensorflow as tf
    # Model Maker needs TensorFlow 2.x. Raise explicitly, because an `assert`
    # is removed when Python runs with -O.
    if not tf.__version__.startswith('2'):
        raise RuntimeError("TensorFlow 2.x is required, found " + tf.__version__)
    
    from tflite_model_maker import configs
    from tflite_model_maker import ExportFormat
    from tflite_model_maker import image_classifier
    from tflite_model_maker import ImageClassifierDataLoader
    from tflite_model_maker import model_spec
    
    import matplotlib.pyplot as plt
    
    image_path = 'E://PycharmProjects//LearingImageData//DataSets//train'
    model_dir = 'E://PycharmProjects//MLTest//tflite_model'

    # Load the labelled images (assumes one sub-folder per class under image_path).
    data = ImageClassifierDataLoader.from_folder(image_path)

    # 80% for training; the remaining 20% is split evenly into validation and test.
    train_data, rest_data = data.split(0.8)
    validation_data, test_data = rest_data.split(0.5)

    # Train with Model Maker's default model spec.
    model = image_classifier.create(train_data, validation_data=validation_data)

    model.summary()

    # Evaluate on the held-out test split and report the result instead of
    # discarding it.
    loss, accuracy = model.evaluate(test_data)
    print('test loss:', loss, 'test accuracy:', accuracy)

    # Export the trained model to model_dir.
    model.export(export_dir=model_dir)

    该卷积神经网络基于 EfficientNet-Lite0(Model Maker 图像分类的默认模型),效果比前面两个模型好得多。

  • 相关阅读:
    笔记04_正确使用Heterogeneous元件
    java网络通信:伪异步I/O编程(PIO)
    java网络通信:异步非阻塞I/O (NIO)
    lua源码学习篇二:语法分析
    lua源码学习篇三:赋值表达式解析的流程
    java网络通信:netty
    lua源码学习篇一:环境部署
    lua源码学习篇四:字节码指令
    java网络通信:同步阻塞式I/O模型(BIO)
    前端项目开发流程
  • 原文地址:https://www.cnblogs.com/robotpaul/p/14507376.html
Copyright © 2011-2022 走看看