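  • Keras Cats vs. Dogs, Part 7: Transfer learning with a pretrained ResNet50, optimized with a dynamically adjusted learning rate, accuracy raised to 96.2%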

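    In https://www.cnblogs.com/zhengbiqing/p/11780161.html, a single classification layer was added directly on top of the ResNet convolutional layers. This gave the simplest possible transfer-learning model, which reached 95.3% accuracy.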

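    Here the final classification head is improved: Flatten is replaced with GlobalAveragePooling2D, and an extra fully connected (Dense) layer is added, followed by BatchNormalization, Activation and Dropout: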

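    from keras.applications import ResNet50
    from keras.layers import Activation, BatchNormalization, Dense, Dropout, GlobalAveragePooling2D
    from keras.models import Model

    # ResNet50 convolutional base with ImageNet weights, without its original classifier
    conv_base = ResNet50(weights='imagenet', include_top=False, input_shape=(150, 150, 3))
    # Freeze every pretrained layer so that only the new head is trained
    for layer in conv_base.layers:
        layer.trainable = False

    x = conv_base.output
    x = GlobalAveragePooling2D()(x)   # average each feature map instead of flattening it
    x = Dense(1024)(x)                # extra fully connected layer
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Dropout(0.3)(x)
    predictions = Dense(1, activation='sigmoid')(x)   # single sigmoid output for the two classes
    model = Model(inputs=conv_base.input, outputs=predictions)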
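    To illustrate why GlobalAveragePooling2D shrinks the head, the plain-Python sketch below (independent of Keras) implements the pooling on a nested-list feature map and counts the weights of the following Dense(1024) layer with Flatten versus pooling. It assumes the ResNet50 base outputs a 5x5x2048 feature map for 150x150 inputs; conv_base.output_shape shows the actual shape. The helper names are hypothetical.

```python
# Minimal sketch: what GlobalAveragePooling2D computes and how it affects the
# size of the next Dense layer. Assumption: a 5x5x2048 ResNet50 feature map.

def global_average_pool(feature_map):
    """Average an H x W x C feature map (nested lists) over H and W, giving C values."""
    h = len(feature_map)
    w = len(feature_map[0])
    c = len(feature_map[0][0])
    sums = [0.0] * c
    for row in feature_map:
        for pixel in row:
            for k in range(c):
                sums[k] += pixel[k]
    return [s / (h * w) for s in sums]

def dense_params(n_in, n_out):
    """Number of weights plus biases in a fully connected layer."""
    return n_in * n_out + n_out

H, W, C = 5, 5, 2048             # assumed ResNet50 output shape for 150x150 inputs
flatten_features = H * W * C     # Flatten keeps every position: 51200 values
gap_features = C                 # pooling keeps one value per channel: 2048 values

print(dense_params(flatten_features, 1024))  # 52429824  (Flatten -> Dense(1024))
print(dense_params(gap_features, 1024))      # 2098176   (GAP -> Dense(1024))
```

    With pooling, the new Dense layer has roughly 25 times fewer weights, which makes the head cheaper to train and less prone to overfitting.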

    In addition, the learning rate is adjusted dynamically during training, and its current value is displayed as a metric:

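    from keras import optimizers

    # Newer Keras versions use learning_rate= instead of the deprecated lr=
    optimizer = optimizers.RMSprop(lr=1e-3)

    def get_lr_metric(optimizer):
        # A "metric" that ignores y_true/y_pred and reports the optimizer's current learning rate,
        # so that it appears in the training log as lr / val_lr
        def lr(y_true, y_pred):
            return optimizer.lr

        return lr

    lr_metric = get_lr_metric(optimizer)

    model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['acc', lr_metric])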
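    The lr metric works because the inner function is a closure over the optimizer object: it reads optimizer.lr each time it is called, so changes made later by a callback show up in the log. Keras names a function metric after the function's __name__, which is why the log shows lr and val_lr. The plain-Python sketch below imitates this with a hypothetical FakeOptimizer class in place of a real Keras optimizer.

```python
class FakeOptimizer:
    """Hypothetical stand-in for a Keras optimizer; it only stores a learning rate."""
    def __init__(self, lr):
        self.lr = lr

def get_lr_metric(optimizer):
    def lr(y_true, y_pred):
        # Labels and predictions are ignored; optimizer.lr is read at call time,
        # so later changes to it are reflected in the returned value
        return optimizer.lr
    return lr

opt = FakeOptimizer(1e-3)
lr_metric = get_lr_metric(opt)
print(lr_metric.__name__)      # lr
print(lr_metric(None, None))   # 0.001
opt.lr *= 0.2                  # the kind of change ReduceLROnPlateau makes on a plateau
print(lr_metric(None, None))   # now reports the reduced learning rate
```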

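    ReduceLROnPlateau multiplies the learning rate by 0.2 when val_loss has not improved for 7 epochs, and EarlyStopping ends training once val_loss has not improved for 13 epochs: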

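    from keras.callbacks import ReduceLROnPlateau, EarlyStopping

    # Stop training after 13 epochs without a val_loss improvement
    early_stop = EarlyStopping(monitor='val_loss', patience=13)
    # Multiply the learning rate by 0.2 after 7 epochs without a val_loss improvement
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', patience=7, mode='auto', factor=0.2)
    callbacks = [early_stop, reduce_lr]

    # train_generator, validation_generator and batch_size come from the data-loading code (not shown here)
    history = model.fit_generator(
          train_generator,
          steps_per_epoch=train_generator.samples // batch_size,
          epochs=100,
          validation_data=validation_generator,
          validation_steps=validation_generator.samples // batch_size,
          callbacks=callbacks)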
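    The interplay of the two callbacks can be simulated on a sequence of val_loss values. The sketch below is a simplified plain-Python model of their bookkeeping, not the Keras source: it keeps a patience counter for each callback, uses EarlyStopping's default min_delta of 0 and ReduceLROnPlateau's default min_delta of 1e-4, and leaves out cooldown and min_lr (both 0 by default). simulate_callbacks is a hypothetical name. Because 13 > 7, a plateau always triggers at least one learning-rate reduction before training is stopped.

```python
def simulate_callbacks(val_losses, lr=1e-3, factor=0.2,
                       reduce_patience=7, stop_patience=13, reduce_min_delta=1e-4):
    """Simplified model of EarlyStopping + ReduceLROnPlateau monitoring val_loss.

    Returns (number of epochs actually run, learning rate used in each epoch).
    """
    best_stop = best_reduce = float('inf')
    wait_stop = wait_reduce = 0
    lrs = []
    for epoch, loss in enumerate(val_losses, start=1):
        lrs.append(lr)  # learning rate in effect during this epoch

        # EarlyStopping (min_delta=0): any decrease counts as an improvement
        if loss < best_stop:
            best_stop, wait_stop = loss, 0
        else:
            wait_stop += 1
            if wait_stop >= stop_patience:
                return epoch, lrs  # training stops after this epoch

        # ReduceLROnPlateau: the decrease must exceed min_delta
        if loss < best_reduce - reduce_min_delta:
            best_reduce, wait_reduce = loss, 0
        else:
            wait_reduce += 1
            if wait_reduce >= reduce_patience:
                lr *= factor        # the new rate applies from the next epoch
                wait_reduce = 0
    return len(val_losses), lrs

# A val_loss that improves for 3 epochs and then stays flat
epochs_run, lrs = simulate_callbacks([0.30, 0.25, 0.22] + [0.22] * 50)
first_reduced = next(i for i, x in enumerate(lrs, start=1) if x < lrs[0])
print(epochs_run)      # 16
print(first_reduced)   # 11
```

    Here the learning rate drops after 7 flat epochs (from epoch 11 on), and training stops after 13 flat epochs, at epoch 16.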

    Training ran for 61 epochs in total, and the learning rate fell from 0.001 to 1.6e-6:

    Epoch 1/100
    281/281 [==============================] - 141s 503ms/step - loss: 0.3322 - acc: 0.8589 - lr: 0.0010 - val_loss: 0.2344 - val_acc: 0.9277 - val_lr: 0.0010
    Epoch 2/100
    281/281 [==============================] - 79s 279ms/step - loss: 0.2591 - acc: 0.8862 - lr: 0.0010 - val_loss: 0.2331 - val_acc: 0.9288 - val_lr: 0.0010
    Epoch 3/100
    281/281 [==============================] - 78s 279ms/step - loss: 0.2405 - acc: 0.8959 - lr: 0.0010 - val_loss: 0.2292 - val_acc: 0.9303 - val_lr: 0.0010
    ......
    281/281 [==============================] - 77s 275ms/step - loss: 0.1532 - acc: 0.9407 - lr: 1.6000e-06 - val_loss: 0.1871 - val_acc: 0.9412 - val_lr: 1.6000e-06
    Epoch 60/100
    281/281 [==============================] - 78s 276ms/step - loss: 0.1492 - acc: 0.9396 - lr: 1.6000e-06 - val_loss: 0.1687 - val_acc: 0.9450 - val_lr: 1.6000e-06
    Epoch 61/100
    281/281 [==============================] - 77s 276ms/step - loss: 0.1468 - acc: 0.9414 - lr: 1.6000e-06 - val_loss: 0.1825 - val_acc: 0.9454 - val_lr: 1.6000e-06
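    The final learning rate is consistent with ReduceLROnPlateau firing four times, since each reduction multiplies the rate by factor=0.2: 0.001 × 0.2^4 = 1.6e-6. A quick check of that arithmetic:

```python
import math

initial_lr, final_lr, factor = 1e-3, 1.6e-6, 0.2

# final_lr = initial_lr * factor ** n  =>  n = log(final_lr / initial_lr) / log(factor)
n_reductions = round(math.log(final_lr / initial_lr) / math.log(factor))
print(n_reductions)  # 4

# Learning rate after each successive reduction
schedule = [initial_lr * factor ** k for k in range(n_reductions + 1)]
print([f"{x:.1e}" for x in schedule])  # ['1.0e-03', '2.0e-04', '4.0e-05', '8.0e-06', '1.6e-06']
```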

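    To load the saved model, the custom lr metric has to be supplied through custom_objects:

    from keras import optimizers
    from keras.models import load_model

    optimizer = optimizers.RMSprop(lr=1e-3)

    def get_lr_metric(optimizer):
        def lr(y_true, y_pred):
            return optimizer.lr

        return lr

    lr_metric = get_lr_metric(optimizer)
    # The key 'lr' must match the metric name stored in the model file (the inner function's name);
    # model_file is the path of the saved model
    model = load_model(model_file, custom_objects={'lr': lr_metric})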
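    custom_objects is needed because the saved model refers to the metric only by its name, 'lr'. When load_model recompiles the model, it has to map that name back to a function, and a name it does not recognize causes an error. The toy resolver below imitates that lookup in plain Python; resolve_metric, BUILTIN_METRICS and the placeholder functions are hypothetical and are not Keras APIs. If the model is only needed for prediction, load_model(model_file, compile=False) skips restoring the training configuration, so the custom metric does not have to be supplied.

```python
# Toy model of name-based lookup when a compiled model is loaded; NOT the Keras implementation.

def accuracy(y_true, y_pred):
    """Placeholder for a built-in metric that the loader already knows."""
    return None

BUILTIN_METRICS = {'acc': accuracy}  # hypothetical registry of known metric names

def resolve_metric(name, custom_objects=None):
    """Turn a saved metric name back into a function: custom_objects first, then built-ins."""
    custom_objects = custom_objects or {}
    if name in custom_objects:
        return custom_objects[name]
    if name in BUILTIN_METRICS:
        return BUILTIN_METRICS[name]
    raise ValueError(f"Unknown metric function: {name}")

def lr(y_true, y_pred):
    """Placeholder for the custom lr metric (the real one returns optimizer.lr)."""
    return 1e-3

saved_metric_names = ['acc', 'lr']  # names as they would be stored with the compiled model

try:
    [resolve_metric(name) for name in saved_metric_names]
except ValueError as err:
    print(err)  # without custom_objects, 'lr' cannot be resolved

metrics = [resolve_metric(name, custom_objects={'lr': lr}) for name in saved_metric_names]
print([m.__name__ for m in metrics])  # ['accuracy', 'lr']
```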

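    Modify the confusion-matrix plotting function so that it also prints the accuracy of each class (the fraction of that class's test images classified correctly, i.e. per-class recall) and the overall accuracy: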

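    import itertools

    import matplotlib.pyplot as plt
    import numpy as np

    def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
        # cm: confusion matrix with true labels as rows and predicted labels as columns
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        plt.colorbar()
        tick_marks = np.arange(len(classes))
        plt.xticks(tick_marks, classes, rotation=45)
        plt.yticks(tick_marks, classes)   # one label centered on each row

        print(cm)
        # Per-class accuracy: correct predictions of class k / all samples whose true label is k
        ok_num = 0
        for k in range(cm.shape[0]):
            print(cm[k, k] / np.sum(cm[k, :]))
            ok_num += cm[k, k]

        # Overall accuracy: all correct predictions / all samples
        print(ok_num / np.sum(cm))

        if normalize:
            cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

        # Write each count in its cell, white on dark cells and black on light ones
        thresh = cm.max() / 2.0
        fmt = '.2f' if normalize else 'd'
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, format(cm[i, j], fmt), horizontalalignment='center',
                     color='white' if cm[i, j] > thresh else 'black')

        plt.tight_layout()
        plt.ylabel('True label')
        plt.xlabel('Predicted label')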

    Test results:

    [[1200   50]
     [  45 1205]]
    0.96
    0.964
    0.962
    The accuracy is 96% for cats and 96.4% for dogs, and 96.2% overall. Confusion matrix plot:

    [Figure: confusion matrix plot]
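    As a cross-check, the printed numbers can be reproduced by hand from the confusion matrix above. The plain-Python sketch below assumes rows are true labels and columns are predictions (as in the plot labels) and that index 0 is cat and index 1 is dog. It also computes per-class precision (column-wise), which the function above does not print, to show how it differs from the per-class accuracy (recall).

```python
# Confusion matrix from the test results: rows = true class, columns = predicted class
cm = [[1200, 50],
      [45, 1205]]
classes = ['cat', 'dog']  # assumed order: index 0 = cat, index 1 = dog

def per_class_recall(cm):
    """Correct predictions of class k / all samples whose true class is k (row-wise)."""
    return [cm[k][k] / sum(cm[k]) for k in range(len(cm))]

def per_class_precision(cm):
    """Correct predictions of class k / all samples predicted as class k (column-wise)."""
    return [cm[k][k] / sum(row[k] for row in cm) for k in range(len(cm))]

def overall_accuracy(cm):
    """All correct predictions (the diagonal) / all samples."""
    correct = sum(cm[k][k] for k in range(len(cm)))
    total = sum(sum(row) for row in cm)
    return correct / total

for name, r, p in zip(classes, per_class_recall(cm), per_class_precision(cm)):
    print(f"{name}: recall={r:.3f} precision={p:.4f}")
print(f"overall accuracy={overall_accuracy(cm):.3f}")
```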

  • Original article: https://www.cnblogs.com/zhengbiqing/p/11964301.html