  • Image classification with VGG16 in Keras

    This is really just a template, and the code should be easy to follow. The labels are stored in a CSV file, and the images are in a folder.
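
A minimal sketch of the data layout the template assumes. There is a CSV with an `image_id` column holding the file name without the `.png` extension, and the training CSV also has a `label` column. The images are PNG files in a folder. The column names, the `.png` suffix and `TRAIN_PATH` come from the code below. The sample rows and the helper `build_paths` are hypothetical, for illustration only.

```python
import csv
import io

# Hypothetical sample of train.csv: image_id is the file name without ".png"
SAMPLE_TRAIN_CSV = """image_id,label
img_001,cat
img_002,dog
img_003,cat
"""

TRAIN_PATH = './data/data/train_img/'

def build_paths(csv_text, img_dir):
    """Return (image paths, labels), joining img_dir + image_id + '.png'
    the same way the template does. Labels are None when the CSV has
    no 'label' column (as in test.csv)."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    paths = [img_dir + row['image_id'] + '.png' for row in rows]
    labels = [row.get('label') for row in rows]
    return paths, labels

paths, labels = build_paths(SAMPLE_TRAIN_CSV, TRAIN_PATH)
print(paths[0])  # ./data/data/train_img/img_001.png
print(labels)    # ['cat', 'dog', 'cat']
```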

    Don't bother trying this without a GPU: a single training run takes a very, very long time...
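
To get a sense of how large the model is, the parameter count can be worked out from the layer sizes. The model is the VGG16 convolutional base (`include_top=False`) on 256×256×3 input, followed by Flatten, Dense(256) and a softmax layer. This sketch counts parameters only and says nothing about wall-clock time, which depends on the GPU, batch size and dataset size. `n_classes=10` is a made-up example value. The five max-pooling layers shrink 256×256 to 8×8, so Flatten yields 8·8·512 = 32,768 features. The Dense(256) layer alone therefore carries about 8.4M weights, on top of roughly 14.7M in the conv base.

```python
# Output channels of the 13 3x3 conv layers in VGG16, grouped by block;
# each block ends with a 2x2 max-pool that halves the spatial size.
VGG16_CONV_CHANNELS = [
    [64, 64],           # block1
    [128, 128],         # block2
    [256, 256, 256],    # block3
    [512, 512, 512],    # block4
    [512, 512, 512],    # block5
]

def conv_base_params(in_channels=3):
    """Weights + biases of all 3x3 conv layers in the VGG16 base."""
    total, c_in = 0, in_channels
    for block in VGG16_CONV_CHANNELS:
        for c_out in block:
            total += 3 * 3 * c_in * c_out + c_out
            c_in = c_out
    return total

def head_params(input_size=256, hidden=256, n_classes=10):
    """Flatten size and parameters of Dense(hidden) and Dense(n_classes)."""
    side = input_size // 2 ** len(VGG16_CONV_CHANNELS)  # 256 -> 8
    flat = side * side * 512
    dense1 = flat * hidden + hidden
    dense2 = hidden * n_classes + n_classes
    return flat, dense1, dense2

base = conv_base_params()
flat, dense1, dense2 = head_params(n_classes=10)  # hypothetical class count
print(base)    # 14714688
print(flat)    # 32768
print(dense1)  # 8388864
```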

    ## import libraries
    import pandas as pd
    import numpy as np
    from skimage import io
    import os, sys
    from tqdm import tqdm
    
    ## load data
    train = pd.read_csv('./data/data/train.csv')
    test = pd.read_csv('./data/data/test.csv')
    
    def read_img(img_path):
        img = io.imread(img_path)
        return img
    
    ## set path for images
    TRAIN_PATH = './data/data/train_img/'
    TEST_PATH = './data/data/test_img/'
    
    
    # load data
    train_img, test_img = [],[]
    for img_path in tqdm(train['image_id'].values):
        train_img.append(read_img(TRAIN_PATH + img_path + '.png'))
    
    for img_path in tqdm(test['image_id'].values):
        test_img.append(read_img(TEST_PATH + img_path + '.png'))
    
    # normalize images
    x_train = np.array(train_img, np.float32) / 255.
    x_test = np.array(test_img, np.float32) / 255.
    
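    # target variable - encode class names as 0-based integers
    # (sorted so the mapping is identical on every run; starting at 0 keeps
    # to_categorical from adding an unused column 0)
    label_list = train['label'].tolist()
    Y_train = {k: v for v, k in enumerate(sorted(set(label_list)))}
    y_train = np.array([Y_train[k] for k in label_list])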
    
    
    # only the imports actually used below (the old keras.layers.normalization
    # path no longer exists in recent Keras versions)
    from keras import applications
    from keras import optimizers
    from keras.models import Model, Sequential
    from keras.layers import Dense, Flatten
    from keras.preprocessing.image import ImageDataGenerator
    from keras.callbacks import ModelCheckpoint
    from keras.utils import to_categorical
    
    y_train = to_categorical(y_train)
    
    # transfer learning with VGG16 (ImageNet weights, fully connected top removed)
    base_model = applications.VGG16(weights='imagenet', include_top=False, input_shape=(256, 256, 3))
    
    ## set model architecture: Flatten -> Dense(256) -> softmax over the classes
    add_model = Sequential()
    add_model.add(Flatten(input_shape=base_model.output_shape[1:]))
    add_model.add(Dense(256, activation='relu'))
    add_model.add(Dense(y_train.shape[1], activation='softmax'))
    
    model = Model(inputs=base_model.input, outputs=add_model(base_model.output))
    # newer Keras versions expect `learning_rate` instead of the deprecated `lr`
    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizers.SGD(learning_rate=1e-4, momentum=0.9),
                  metrics=['accuracy'])
    
    model.summary()
    
    batch_size = 128 # tune it
    epochs = 30 # increase it
    train_datagen = ImageDataGenerator(
            shear_range=0.2,
            zoom_range=0.2,
            rotation_range=30, 
            width_shift_range=0.1,
            height_shift_range=0.1, 
            horizontal_flip=True)
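    # ImageDataGenerator.fit() is only needed for featurewise_center/std or
    # zca_whitening, none of which are used here, so it is skipped.
    
    # hold out 10% of the training data so ModelCheckpoint has a
    # validation metric to monitor (without validation data it never saves)
    rng = np.random.RandomState(42)
    idx = rng.permutation(len(x_train))
    n_val = len(x_train) // 10
    x_tr, y_tr = x_train[idx[n_val:]], y_train[idx[n_val:]]
    x_val, y_val = x_train[idx[:n_val]], y_train[idx[:n_val]]
    
    # model.fit accepts generators directly; fit_generator is deprecated
    history = model.fit(
        train_datagen.flow(x_tr, y_tr, batch_size=batch_size),
        steps_per_epoch=len(x_tr) // batch_size,
        epochs=epochs,
        validation_data=(x_val, y_val),
        callbacks=[ModelCheckpoint('VGG16-transferlearning2.h5', monitor='val_accuracy', save_best_only=True)]
    )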
    
    ## predict test data
    predictions = model.predict(x_test)
    
    
    # get labels
    predictions = np.argmax(predictions, axis=1)
    rev_y = {v:k for k,v in Y_train.items()}
    pred_labels = [rev_y[k] for k in predictions]
    
    ## make submission
    sub = pd.DataFrame({'image_id':test.image_id, 'label':pred_labels})
    sub.to_csv('sub_vgg2.csv', index=False) ## ~0.59
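
The label handling in the template (class name to integer, one-hot via `to_categorical`, `argmax` over the softmax output, then back through `rev_y`) can be checked on its own. Below is a pure-Python sketch that assumes 0-based indices assigned to the sorted class names. `one_hot` is a hypothetical stand-in for `keras.utils.to_categorical`, so Keras is not needed to run it. With 0-based indices the number of one-hot columns equals the number of classes, so every `argmax` result maps back to a real label.

```python
def encode_labels(labels):
    """Map each class name to a 0-based index (sorted for a stable order)."""
    classes = sorted(set(labels))
    to_idx = {name: i for i, name in enumerate(classes)}
    return to_idx, [to_idx[name] for name in labels]

def one_hot(indices, num_classes):
    """Hypothetical stand-in for keras.utils.to_categorical."""
    return [[1.0 if j == i else 0.0 for j in range(num_classes)] for i in indices]

def decode(probabilities, to_idx):
    """argmax over each row, then invert the mapping (like rev_y above)."""
    rev = {i: name for name, i in to_idx.items()}
    return [rev[max(range(len(row)), key=row.__getitem__)] for row in probabilities]

to_idx, y = encode_labels(['dog', 'cat', 'dog', 'bird'])
print(to_idx)  # {'bird': 0, 'cat': 1, 'dog': 2}
print(y)       # [2, 1, 2, 0]
Y = one_hot(y, len(to_idx))
print(decode(Y, to_idx))  # ['dog', 'cat', 'dog', 'bird']
```

If the indices instead start at 1 (`v+1`), `to_categorical` infers one extra column (index 0) that is never a target, and an `argmax` of 0 would have no entry in the reverse mapping.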
  • Original post: https://www.cnblogs.com/qscqesze/p/7560159.html