zoukankan      html  css  js  c++  java
  • 【TensorFlow】使用TensorFlow Lite Model Maker训练模型

    基础CNN模型(类LeNet-5结构)

    import tensorflow as tf
    from keras.preprocessing.image import ImageDataGenerator
    from tensorflow.keras.optimizers import RMSprop
    
    # Image folders: each must contain one sub-folder per class.
    train_dir="E://PycharmProjects//LearingImageData//DataSets//train"
    validation_dir="E://PycharmProjects//LearingImageData//DataSets//var"

    # Training pipeline: scale pixels to [0, 1] and apply random augmentation.
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=40,        # random rotation in [-40, 40] degrees
        width_shift_range=0.2,    # random horizontal shift (fraction of width)
        height_shift_range=0.2,   # random vertical shift (fraction of height)
        shear_range=0.2,          # shear intensity
        zoom_range=0.2,           # random zoom range
        horizontal_flip=True,     # random horizontal flips
        fill_mode='nearest')      # how pixels exposed by the transforms are filled
    train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=(300, 300),
        batch_size=128,           # 128 images per batch
        # One-hot labels: works for two or more classes and must match the
        # validation generator, the softmax output and the loss below.
        class_mode='categorical'
    )

    # Validation pipeline: rescaling only, no augmentation.
    test_datagen = ImageDataGenerator(rescale=1./255)
    validation_generator = test_datagen.flow_from_directory(
        validation_dir,
        target_size=(300, 300),
        batch_size=128,           # 128 images per batch
        # Must use the same label encoding as the training generator.
        class_mode='categorical'
    )

    # Both folders must define the same classes in the same order, otherwise the
    # one-hot label columns of training and validation data would not line up.
    if validation_generator.class_indices != train_generator.class_indices:
        raise ValueError("train and validation folders contain different classes: "
                         f"{train_generator.class_indices} vs "
                         f"{validation_generator.class_indices}")

    # One output unit per class sub-folder found in train_dir.
    num_classes = train_generator.num_classes

    # A small CNN: three conv + max-pool stages followed by a dense classifier.
    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(16, (3, 3), activation='relu',
                               input_shape=(300, 300, 3)),
        tf.keras.layers.MaxPool2D(2, 2),
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
        tf.keras.layers.MaxPool2D(2, 2),
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
        tf.keras.layers.MaxPool2D(2, 2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512, activation='relu'),
        # One softmax unit per class, matching the one-hot labels.
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ])

    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',  # one-hot labels + softmax output
                  metrics=['accuracy'])
    model.summary()
    history = model.fit(
        train_generator,
        epochs=20,
        validation_data=validation_generator,
        steps_per_epoch=2,        # only 2 training batches per epoch
        validation_steps=1,       # only 1 validation batch per epoch
        verbose=2)
    model.save("E://PycharmProjects//LearingImageData//DataSets//finalModel")

    基于MobileNetV2的改进模型(训练时自动保存准确度大于0.8的模型,并转换为TFLite格式)

    import tensorflow as tf
    import keras
    import keras_applications.mobilenet_v2 as mobilenet
    import numpy as np
    
    
    def get_model():
        """Build a 3-class classifier on top of a frozen, ImageNet-pretrained MobileNetV2.

        The MobileNetV2 backbone (224x224 RGB input, no top) has all of its
        layers frozen. A new head of global average pooling plus an
        L1-regularised 3-unit softmax is attached.

        Returns:
            An uncompiled ``keras.models.Model`` named ``'zc_classifier'``.
        """
        # Backbone without its ImageNet classification head; the keras
        # sub-modules are injected because keras_applications needs them.
        backbone = mobilenet.MobileNetV2(input_shape=(224, 224, 3),
                                         include_top=False,
                                         weights='imagenet',
                                         backend=keras.backend,
                                         layers=keras.layers,
                                         models=keras.models,
                                         utils=keras.utils)

        # Freeze every backbone layer so only the new head is trained.
        for backbone_layer in backbone.layers:
            backbone_layer.trainable = False

        # New head: spatial average pooling, then a regularised 3-way softmax.
        pooled = keras.layers.GlobalAveragePooling2D()(backbone.output)
        probabilities = keras.layers.Dense(
            units=3,
            activation='softmax',
            kernel_regularizer=tf.keras.regularizers.l1(0.005),
        )(pooled)

        return keras.models.Model(backbone.input, probabilities, name='zc_classifier')
    
    
    def get_train_input(train_dir=r'E:\PycharmProjects\LearingImageData\DataSets\train'):
        """Yield augmented training batches with hand-tuned soft labels.

        Images are read from ``train_dir`` (one sub-folder per class), resized to
        224x224 with bilinear interpolation, randomly augmented, and passed
        through MobileNetV2's ``preprocess_input``. The one-hot labels of each
        batch are then replaced by fixed soft targets for three classes.

        Args:
            train_dir: Root folder of the training images.

        Yields:
            ``(images, labels)`` batches of up to 32 samples.
        """
        train_image = keras.preprocessing.image.ImageDataGenerator(
            featurewise_center=False, samplewise_center=False,
            featurewise_std_normalization=False,
            samplewise_std_normalization=False, zca_whitening=False,
            zca_epsilon=1e-06,
            # NOTE(review): rotation_range is in degrees, so 0.1 gives an almost
            # negligible rotation -- confirm this is intended.
            rotation_range=0.1,
            width_shift_range=0.1, height_shift_range=0.1,
            shear_range=0.0,
            zoom_range=0.1, channel_shift_range=0.0,
            fill_mode='nearest', cval=0.0, horizontal_flip=True,
            vertical_flip=False, rescale=None,
            # Apply MobileNetV2's own input preprocessing to every image.
            preprocessing_function=lambda x: mobilenet.preprocess_input(
                x, data_format="channels_last"),
            data_format="channels_last")

        train_gen = train_image.flow_from_directory(train_dir,
                                                    target_size=(224, 224),
                                                    batch_size=32,
                                                    interpolation="bilinear")
        for img, lab in train_gen:
            # Replace each one-hot target with a soft target chosen by its class.
            classes = np.argmax(lab, axis=-1)
            lab[classes == 0] = [0.85, 0.075, 0.075]
            lab[classes == 1] = [0.05, 0.8, 0.15]
            lab[classes == 2] = [0.05, 0.15, 0.8]
            yield img, lab
    
    
    def get_test_input(test_dir=r'E:\PycharmProjects\LearingImageData\DataSets\var'):
        """Return an unshuffled, un-augmented iterator over the evaluation images.

        Images are read from ``test_dir`` (one sub-folder per class), resized to
        224x224 with bilinear interpolation and passed through MobileNetV2's
        ``preprocess_input``. Batches hold one image each and keep directory order.

        Args:
            test_dir: Root folder of the evaluation images.

        Returns:
            The iterator produced by ``flow_from_directory``.
        """
        test_image = keras.preprocessing.image.ImageDataGenerator(
            # Same preprocessing as training, but no augmentation.
            preprocessing_function=lambda x: mobilenet.preprocess_input(
                x, data_format="channels_last"),
            data_format="channels_last")
        test_gen = test_image.flow_from_directory(test_dir,
                                                  target_size=(224, 224),
                                                  batch_size=1,
                                                  shuffle=False,
                                                  interpolation="bilinear")
        return test_gen
    
    
    def custom_loss(y_true, y_pred):
        """Class-weighted binary cross-entropy over the three output units.

        Each output column is scored as an independent binary target. The
        per-column cross-entropies are weighted by ``class_weight`` and summed per
        sample, and the negated batch mean is returned.

        Args:
            y_true: Target probabilities; assumed shape (batch, 3) -- TODO confirm.
            y_pred: Predicted probabilities; assumed shape (batch, 3) -- TODO confirm.

        Returns:
            A scalar loss tensor.
        """
        # Per-class weights: the more samples a class has, the lower its weight.
        class_weight = tf.constant([0.33, 0.33, 0.34], dtype=y_pred.dtype)
        y_true = tf.cast(y_true, y_pred.dtype)
        # Keep predictions strictly inside (0, 1). A saturated softmax would
        # otherwise give log(0) = -inf and turn the loss into NaN.
        eps = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)
        per_class = (y_true * tf.math.log(y_pred) +
                     (1.0 - y_true) * tf.math.log(1.0 - y_pred))
        loss = tf.reduce_sum(per_class * class_weight, axis=-1)
        return -tf.reduce_mean(loss)
    
    
    def train(model: tf.keras.Model, train_gen, test_gen, save_function):
        """Compile and fit ``model``, then save it to ``./zc.h5``.

        Args:
            model: The classifier to train.
            train_gen: Batches of ``(images, labels)`` for training.
            test_gen: Validation data evaluated at the end of each epoch.
            save_function: A Keras callback invoked during training.
        """
        adam = keras.optimizers.Adam(0.01)
        model.compile(loss=custom_loss, optimizer=adam, metrics=['accuracy'])
        # fit() accepts generators directly; fit_generator() is deprecated and
        # has been removed from newer Keras releases.
        model.fit(train_gen,
                  epochs=8,
                  steps_per_epoch=6,  # 6 training batches make up one epoch
                  validation_data=test_gen,
                  callbacks=[save_function])
        model.save('./zc.h5')
    
    
    def test(model: tf.keras.Model, test_gen):
        """Evaluate ``model`` on ``test_gen`` and print the result.

        Args:
            model: A compiled model.
            test_gen: Evaluation batches of ``(images, labels)``.

        Returns:
            The value returned by ``model.evaluate`` (loss and metrics), which is
            also printed.
        """
        # evaluate() accepts generators directly; evaluate_generator() is deprecated.
        results = model.evaluate(test_gen, verbose=1)
        print(results)
        return results
    
    
    class Save(keras.callbacks.Callback):
        """Keras callback that checkpoints the model whenever accuracy improves.

        At the end of every epoch except epoch 0, the monitored accuracy is
        checked. If it beats the best value saved so far and exceeds
        ``threshold``, the model is written to an ``.h5`` file and also converted
        to a ``.tflite`` file.

        Args:
            monitor: Key in the epoch ``logs`` to track. If the key is missing
                (e.g. no validation data), training ``"accuracy"`` is used instead.
            threshold: Minimum accuracy before anything is saved.
        """

        def __init__(self, monitor="val_accuracy", threshold=0.8):
            super(Save, self).__init__()
            self.monitor = monitor
            self.threshold = threshold
            self.max_acc = 0.0  # best accuracy saved so far

        def on_epoch_end(self, epoch, logs=None):
            logs = logs or {}
            self.val_acc = logs.get(self.monitor, logs.get("accuracy"))
            if self.val_acc is None:
                return  # no accuracy was reported for this epoch
            if epoch == 0:
                return  # the first epoch is never saved
            if self.val_acc > self.max_acc and self.val_acc > self.threshold:
                # self.model is the model Keras is training, so the callback
                # does not depend on a module-level variable named `model`.
                self.model.save("kears_model_" + str(epoch) + "_acc=" + str(self.val_acc) + ".h5")
                self.max_acc = self.val_acc
                converter = tf.lite.TFLiteConverter.from_keras_model(self.model)
                converter.experimental_new_converter = True
                tflite_model = converter.convert()
                # Context manager guarantees the file is flushed and closed.
                with open("converted_model_" + str(self.val_acc) + ".tflite", "wb") as f:
                    f.write(tflite_model)
    
    
    if __name__ == "__main__":
        # Build the network and its data pipelines, then train with the
        # checkpointing callback attached.
        model = get_model()
        train_gen, test_gen = get_train_input(), get_test_input()
        save_function = Save()
        train(model, train_gen, test_gen, save_function)
        # To evaluate instead, call: test(model, test_gen)

    使用TensorFlow Lite Model Maker训练模型

    # 先在命令行中执行:pip install -q tflite-model-maker
    import os
    
    import numpy as np
    
    import tensorflow as tf
    assert tf.__version__.startswith('2')
    
    from tflite_model_maker import configs
    from tflite_model_maker import ExportFormat
    from tflite_model_maker import image_classifier
    from tflite_model_maker import ImageClassifierDataLoader
    from tflite_model_maker import model_spec
    
    import matplotlib.pyplot as plt
    
    # Image folder (one sub-folder per class) and the output directory for the export.
    image_path='E://PycharmProjects//LearingImageData//DataSets//train'
    model_dir='E://PycharmProjects//MLTest//tflite_model'

    # NOTE(review): newer tflite-model-maker releases expose this loader as
    # image_classifier.DataLoader -- confirm against the installed version.
    data = ImageClassifierDataLoader.from_folder(image_path)

    # 80% train; the remaining 20% is split evenly into validation and test.
    train_data, rest_data = data.split(0.8)
    validation_data, test_data = rest_data.split(0.5)

    model = image_classifier.create(train_data, validation_data=validation_data)

    model.summary()

    # Report held-out performance instead of silently discarding it.
    loss, accuracy = model.evaluate(test_data)
    print(f"Test loss: {loss:.4f}, test accuracy: {accuracy:.4f}")

    # Export the trained model to model_dir.
    model.export(export_dir=model_dir)

    该卷积神经网络默认基于EfficientNet-Lite0,效果比前面两个模型好得多。

  • 相关阅读:
    web前端的发展态势
    AngularJs 简单入门
    css代码优化篇
    git提交报错:Please make sure you have the correct access rights and the repository exists.
    Activiti工作流框架学习
    遍历map集合的4种方法
    js设置日期、月份增加减少
    Invalid character found in the request target. The valid characters are defined in RFC 7230 and RFC 3986
    webservice_rest接口_学习笔记
    相互匹配两个list集合+动态匹配${}参数
  • 原文地址:https://www.cnblogs.com/robotpaul/p/14507376.html
Copyright © 2011-2022 走看看