  • Image classification with VGG16 in Keras

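    This is really just a template, and the code should be easy to follow. The labels are stored in a CSV file (one row per image, with image_id and label columns) and the images sit in a folder, one <image_id>.png file per row.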
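The data layout described above can be read with a few lines of standard-library Python. This is a minimal sketch, not the author's code: it assumes the image_id/label column names and the <image_id>.png naming that the script below uses, and the helper names load_index, encode_labels and decode_labels are hypothetical. It also shows a reproducible label-to-index mapping (sorted classes, 0-based), which is the form to_categorical and argmax expect.

```python
import csv
import os
import tempfile


def load_index(csv_path, img_dir, ext=".png"):
    """Return (image_paths, labels) for a CSV with an 'image_id' column.

    Hypothetical helper: assumes each image is stored as <img_dir>/<image_id><ext>.
    A test CSV without a 'label' column yields None for every label.
    """
    paths, labels = [], []
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            paths.append(os.path.join(img_dir, row["image_id"] + ext))
            labels.append(row.get("label"))
    return paths, labels


def encode_labels(labels):
    """Map label strings to 0-based class indices; sorting makes the order reproducible."""
    classes = sorted(set(labels))
    to_index = {c: i for i, c in enumerate(classes)}
    return [to_index[c] for c in labels], classes


def decode_labels(indices, classes):
    """Inverse of encode_labels, e.g. for argmax outputs of the softmax layer."""
    return [classes[i] for i in indices]


if __name__ == "__main__":
    # Tiny made-up example: three rows, two classes.
    with tempfile.TemporaryDirectory() as d:
        csv_path = os.path.join(d, "train.csv")
        with open(csv_path, "w", newline="") as f:
            f.write("image_id,label\nimg_1,cat\nimg_2,dog\nimg_3,cat\n")
        paths, labels = load_index(csv_path, os.path.join(d, "train_img"))
        y, classes = encode_labels(labels)
        print(classes, y)  # ['cat', 'dog'] [0, 1, 0]
        print(decode_labels([1, 0], classes))  # ['dog', 'cat']
```

In the script itself pandas does the CSV reading; the sketch only makes the file-naming convention and the deterministic label order explicit.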

    If you don't have a GPU, don't bother trying: a single training run takes a very, very long time.

    ## import libraries
    import pandas as pd
    import numpy as np
    from skimage import io
    import os, sys
    from tqdm import tqdm
    
    ## load data
    train = pd.read_csv('./data/data/train.csv')
    test = pd.read_csv('./data/data/test.csv')
    
    def read_img(img_path):
        img = io.imread(img_path)
        return img
    
    ## set path for images
    TRAIN_PATH = './data/data/train_img/'
    TEST_PATH = './data/data/test_img/'
    
    
    # read every train/test image into memory
    train_img, test_img = [],[]
    for img_path in tqdm(train['image_id'].values):
        train_img.append(read_img(TRAIN_PATH + img_path + '.png'))
    
    for img_path in tqdm(test['image_id'].values):
        test_img.append(read_img(TEST_PATH + img_path + '.png'))
    
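    # scale pixel values to [0, 1]
    x_train = np.array(train_img, np.float32) / 255.
    x_test = np.array(test_img, np.float32) / 255.
    # fail early if the images are not 256x256 RGB (the VGG16 input_shape below)
    assert x_train.shape[1:] == (256, 256, 3), x_train.shape
    assert x_test.shape[1:] == (256, 256, 3), x_test.shape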
    
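    # target variable - map each label to a 0-based class index so that
    # to_categorical creates exactly one column per class; sorted() keeps the
    # mapping identical across runs (set iteration order can change between runs)
    label_list = train['label'].tolist()
    Y_train = {k: v for v, k in enumerate(sorted(set(label_list)))}
    y_train = np.array([Y_train[k] for k in label_list])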
    
    
    from keras import applications
    from keras import optimizers
    from keras.models import Model, Sequential
    from keras.layers import Dense, Flatten
    from keras.preprocessing.image import ImageDataGenerator
    from keras.callbacks import ModelCheckpoint
    from keras.utils import to_categorical
    
    y_train = to_categorical(y_train)
    
    # transfer learning from VGG16 pretrained on ImageNet (include_top=False drops its dense classifier)
    base_model = applications.VGG16(weights='imagenet', include_top=False, input_shape=(256, 256, 3))
    
    ## new classifier head on top of the VGG16 convolutional base
    add_model = Sequential()
    add_model.add(Flatten(input_shape=base_model.output_shape[1:]))
    add_model.add(Dense(256, activation='relu'))
    add_model.add(Dense(y_train.shape[1], activation='softmax'))
    
    model = Model(inputs=base_model.input, outputs=add_model(base_model.output))
    model.compile(loss='categorical_crossentropy', optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),
                metrics=['accuracy'])
    
    model.summary()
    
    batch_size = 128 # tune it (lower it if the GPU runs out of memory)
    epochs = 30 # increase it
    train_datagen = ImageDataGenerator(
            shear_range=0.2,
            zoom_range=0.2,
            rotation_range=30, 
            width_shift_range=0.1,
            height_shift_range=0.1, 
            horizontal_flip=True)
    train_datagen.fit(x_train)
    
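    # hold out 10% of the (shuffled) training images for validation: without
    # validation data, ModelCheckpoint has no val_* metric and never saves a model
    perm = np.random.RandomState(42).permutation(len(x_train))
    x_train, y_train = x_train[perm], y_train[perm]
    n_val = len(x_train) // 10
    x_val, y_val = x_train[:n_val], y_train[:n_val]
    x_tr, y_tr = x_train[n_val:], y_train[n_val:]
    
    history = model.fit_generator(
        train_datagen.flow(x_tr, y_tr, batch_size=batch_size),
        steps_per_epoch=x_tr.shape[0] // batch_size,
        epochs=epochs,
        validation_data=(x_val, y_val),
        # val_loss has the same name in every Keras version ('val_acc' was later renamed 'val_accuracy')
        callbacks=[ModelCheckpoint('VGG16-transferlearning2.model', monitor='val_loss', save_best_only=True)]
    )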
    
    ## predict test data
    predictions = model.predict(x_test)
    
    
    # get labels
    predictions = np.argmax(predictions, axis=1)
    rev_y = {v:k for k,v in Y_train.items()}
    pred_labels = [rev_y[k] for k in predictions]
    
    ## make submission
    sub = pd.DataFrame({'image_id':test.image_id, 'label':pred_labels})
    sub.to_csv('sub_vgg2.csv', index=False) ## score reported by the author: ~0.59
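Most of the new trainable parameters sit in the head built with Flatten and Dense(256), because Flatten turns the whole VGG16 feature map into one long vector. The following is a back-of-the-envelope sketch, assuming the standard VGG16 convolutional base ('same'-padded convolutions, five 2x2 stride-2 max-pooling layers, 512 channels in the last block). The function name and the n_classes=10 default are hypothetical, since the post does not say how many classes the dataset has.

```python
def vgg16_head_params(height=256, width=256, hidden=256, n_classes=10):
    """Estimate the size of Flatten -> Dense(hidden) -> Dense(n_classes) on VGG16(include_top=False).

    Assumes the stock VGG16 base: convolutions keep the spatial size, and each of
    the five 2x2, stride-2 max-pooling layers halves it (integer division).
    n_classes=10 is a made-up default, not taken from the post.
    """
    h, w = height, width
    for _ in range(5):  # five pooling stages
        h, w = h // 2, w // 2
    feature_shape = (h, w, 512)  # last conv block has 512 channels
    flat = h * w * 512  # length of the Flatten output
    dense1 = flat * hidden + hidden  # weights + biases of Dense(hidden)
    dense2 = hidden * n_classes + n_classes  # weights + biases of the softmax layer
    return feature_shape, flat, dense1 + dense2


if __name__ == "__main__":
    print(vgg16_head_params())  # ((8, 8, 512), 32768, 8391434)
```

Under these assumptions, a 256x256 input gives an 8x8x512 feature map and a head of roughly 8.4 million parameters, almost all in the first Dense layer. The Param # that model.summary() shows for the Sequential head should match the third value once n_classes is set to the real class count. A GlobalAveragePooling2D layer in place of Flatten would shrink that first layer to 512 * 256 + 256 weights, though that is a different design from the one used in the post.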
  • Original article: https://www.cnblogs.com/qscqesze/p/7560159.html