  • Image classification with VGG16 in Keras

    This is really just a template, and the code should be easy to follow. The labels are stored in a CSV file and the images are in a folder.
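    A minimal sketch of the data layout the script below assumes: a CSV with an `image_id` column and a `label` column, and one `<image_id>.png` per row in the image folder. The ids and class names here (`img_001`, `cat`, `dog`, `bird`) are made up for illustration. It also shows label encoding on its own: sorting the class names gives a mapping to 0-based indices that is the same on every run, and 0-based indices mean a one-hot encoding (such as Keras' `to_categorical`) has exactly one column per class.

```python
import csv
import io

import numpy as np

# A tiny stand-in for ./data/data/train.csv (hypothetical ids and labels).
TRAIN_CSV = """image_id,label
img_001,cat
img_002,dog
img_003,cat
img_004,bird
"""
TRAIN_PATH = "./data/data/train_img/"

rows = list(csv.DictReader(io.StringIO(TRAIN_CSV)))

# Each image is expected at TRAIN_PATH + image_id + ".png"
image_paths = [TRAIN_PATH + r["image_id"] + ".png" for r in rows]
labels = [r["label"] for r in rows]

# Sorting makes the class -> index mapping identical on every run
# (iteration order of a plain set of strings can change between runs).
class_to_idx = {name: i for i, name in enumerate(sorted(set(labels)))}
idx_to_class = {i: name for name, i in class_to_idx.items()}

y = np.array([class_to_idx[name] for name in labels])
# One row per image, one column per class (0-based indices, no unused column)
y_onehot = np.eye(len(class_to_idx), dtype=np.float32)[y]

# Decoding: argmax over the class axis, then map each index back to its name.
decoded = [idx_to_class[i] for i in np.argmax(y_onehot, axis=1)]

print(image_paths[0])     # ./data/data/train_img/img_001.png
print(class_to_idx)       # {'bird': 0, 'cat': 1, 'dog': 2}
print(y_onehot.shape)     # (4, 3)
print(decoded == labels)  # True
```

    The same round trip (name to index, one-hot, `argmax`, index back to name) is what the training and submission steps of the script rely on.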

    Don't bother trying this without a GPU: a single training run takes a very long time.

    ## import libraries
    import pandas as pd
    import numpy as np
    from skimage import io
    from tqdm import tqdm
    
    ## load data
    train = pd.read_csv('./data/data/train.csv')
    test = pd.read_csv('./data/data/test.csv')
    
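    def read_img(img_path):
        img = io.imread(img_path)
        # VGG16 expects 3 channels: expand grayscale images and drop a PNG alpha channel
        # (images are assumed to already be 256x256, matching input_shape below)
        if img.ndim == 2:
            img = np.stack([img] * 3, axis=-1)
        return img[..., :3]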
    
    ## set path for images
    TRAIN_PATH = './data/data/train_img/'
    TEST_PATH = './data/data/test_img/'
    
    
    # load images
    train_img, test_img = [],[]
    for img_path in tqdm(train['image_id'].values):
        train_img.append(read_img(TRAIN_PATH + img_path + '.png'))
    
    for img_path in tqdm(test['image_id'].values):
        test_img.append(read_img(TEST_PATH + img_path + '.png'))
    
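    # normalize images: scale pixels to [0, 1]
    # (keras.applications.vgg16.preprocess_input is the preprocessing the ImageNet
    # weights were trained with and may be worth trying instead)
    x_train = np.array(train_img, np.float32) / 255.
    x_test = np.array(test_img, np.float32) / 255.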
    
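    # target variable - encode class names as 0-based integers
    # (sorted so the mapping is the same on every run; starting at 0 means
    # to_categorical produces exactly one column per class)
    label_list = train['label'].tolist()
    Y_train = {k: v for v, k in enumerate(sorted(set(label_list)))}
    y_train = np.array([Y_train[k] for k in label_list])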
    
    
    from keras import applications
    from keras import optimizers
    from keras.models import Model, Sequential
    from keras.layers import Dense, Flatten
    from keras.preprocessing.image import ImageDataGenerator
    from keras.callbacks import ModelCheckpoint
    from keras.utils import to_categorical
    
    y_train = to_categorical(y_train)
    
    # transfer learning with VGG16 (ImageNet weights, without the fully connected top)
    base_model = applications.VGG16(weights='imagenet', include_top=False, input_shape=(256, 256, 3))
    
    ## set model architecture: a small classification head on top of the VGG16 base
    add_model = Sequential()
    add_model.add(Flatten(input_shape=base_model.output_shape[1:]))
    add_model.add(Dense(256, activation='relu'))
    add_model.add(Dense(y_train.shape[1], activation='softmax'))
    
    model = Model(inputs=base_model.input, outputs=add_model(base_model.output))
    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizers.SGD(learning_rate=1e-4, momentum=0.9),
                  metrics=['accuracy'])
    
    model.summary()
    
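    batch_size = 128  # tune it (lower it if the GPU runs out of memory)
    epochs = 30  # increase it

    # hold out 10% of the training images as a validation set, so that
    # ModelCheckpoint has a validation metric to monitor
    rng = np.random.RandomState(42)
    idx = rng.permutation(len(x_train))
    n_val = len(x_train) // 10
    x_val, y_val = x_train[idx[:n_val]], y_train[idx[:n_val]]
    x_trn, y_trn = x_train[idx[n_val:]], y_train[idx[n_val:]]

    # random augmentation applied to each training batch
    train_datagen = ImageDataGenerator(
            shear_range=0.2,
            zoom_range=0.2,
            rotation_range=30,
            width_shift_range=0.1,
            height_shift_range=0.1,
            horizontal_flip=True)
    # (train_datagen.fit() is only needed for featurewise_center/std or ZCA whitening)

    # fit() accepts the generator directly (older Keras versions used fit_generator)
    history = model.fit(
        train_datagen.flow(x_trn, y_trn, batch_size=batch_size),
        steps_per_epoch=len(x_trn) // batch_size,
        epochs=epochs,
        validation_data=(x_val, y_val),
        callbacks=[ModelCheckpoint('VGG16-transferlearning2.model', monitor='val_loss', save_best_only=True)]
    )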
    
    ## predict test data
    predictions = model.predict(x_test)
    
    
    # get labels
    predictions = np.argmax(predictions, axis=1)
    rev_y = {v:k for k,v in Y_train.items()}
    pred_labels = [rev_y[k] for k in predictions]
    
    ## make submission
    sub = pd.DataFrame({'image_id':test.image_id, 'label':pred_labels})
    sub.to_csv('sub_vgg2.csv', index=False) ## ~0.59
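    A back-of-the-envelope check of the classification head in the script above, assuming 256x256x3 inputs as in `input_shape`. VGG16's convolutional base keeps the spatial size through its convolutions and has five 2x2 max-pooling layers with stride 2, so each side shrinks by 2**5 = 32, and the last block has 512 channels. The helper names below are hypothetical, and `num_classes = 10` is a made-up value; the script takes the real count from `y_train.shape[1]`.

```python
def vgg16_base_output_shape(height, width, channels_out=512, num_pools=5):
    """Spatial size after VGG16's conv base: each 2x2/stride-2 pool halves H and W (floor)."""
    for _ in range(num_pools):
        height, width = height // 2, width // 2
    return height, width, channels_out


def dense_params(n_in, n_out):
    """A Dense layer has one weight per input/output pair plus one bias per output."""
    return n_in * n_out + n_out


num_classes = 10  # hypothetical; the script uses y_train.shape[1]

h, w, c = vgg16_base_output_shape(256, 256)
flat = h * w * c                          # size of the Flatten output
hidden = dense_params(flat, 256)          # Dense(256, activation='relu')
output = dense_params(256, num_classes)   # Dense(num_classes, activation='softmax')

print((h, w, c))  # (8, 8, 512)
print(flat)       # 32768
print(hidden)     # 8388864
print(output)     # 2570
```

    Almost all of the head's weights sit in the first `Dense` layer. A common alternative is `GlobalAveragePooling2D` instead of `Flatten`, which averages each of the 512 feature maps to one value; the hidden layer would then have `dense_params(512, 256)` = 131328 weights. Its effect on accuracy for this dataset is not something the post measured.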
  • Original post: https://www.cnblogs.com/qscqesze/p/7560159.html