zoukankan      html  css  js  c++  java
  • 吴裕雄 python神经网络(8)

    # -*- coding=utf-8 -*-
    import numpy as np
    import keras
    from keras.models import Sequential
    from keras.layers import Dense,Flatten,Dropout
    from keras.optimizers import Adadelta
    from keras.datasets import cifar10
    from keras import applications

    import matplotlib.pyplot as plt
    %matplotlib inline

    vgg_model=applications.VGG19(include_top=False,weights='imagenet')
    vgg_model.summary()

    (train_x,train_y),(test_x,test_y)=cifar10.load_data()
    print(train_x.shape,train_y.shape,test_x.shape,test_y.shape)

    n_classes=10
    train_y=keras.utils.to_categorical(train_y,n_classes)
    test_y=keras.utils.to_categorical(test_y,n_classes)

    bottleneck_feature_train=vgg_model.predict(train_x,verbose=1)
    bottleneck_feature_test=vgg_model.predict(test_x,verbose=1)

    print(bottleneck_feature_train.shape,bottleneck_feature_test.shape)

    my_model=Sequential()
    my_model.add(Flatten())###my_model.add(Flatten(input_shape=?))
    my_model.add(Dense(512,activation='relu'))
    my_model.add(Dropout(0.5))
    my_model.add(Dense(256,activation='relu'))
    my_model.add(Dropout(0.5))
    my_model.add(Dense(n_classes,activation='softmax'))
    my_model.compile(optimizer=Adadelta(),loss="categorical_crossentropy",
    metrics=['accuracy'])
    my_model.fit(bottleneck_feature_train,train_y,batch_size=128,epochs=50,verbose=1)

    evaluation=my_model.evaluate(bottleneck_feature_test,test_y,batch_size=128,verbose=0)
    print("loss:",evaluation[0],"accuracy:",evaluation[1])

    def predict_label(img_idx, show_proba=True):
        """Show training image ``img_idx`` and print its true vs. predicted class.

        Reads the module-level ``train_x`` (raw images, for display),
        ``bottleneck_feature_train`` (VGG19 features fed to the classifier),
        ``train_y`` (one-hot labels) and the trained classifier ``my_model``.

        Args:
            img_idx: Index of the sample in the training set.
            show_proba: If true, also print the model's class-probability output.
        """
        plt.imshow(train_x[img_idx], aspect='auto')
        plt.title("Image to be labeled")
        plt.show()

        # Add a leading batch axis of size 1; works for features of any rank.
        img_4D = bottleneck_feature_train[img_idx][np.newaxis, ...]
        # Sequential.predict_classes / predict_proba no longer exist in recent
        # Keras releases. The model ends in a softmax layer, so predict()
        # already returns class probabilities and their argmax is the class.
        proba = my_model.predict(img_4D, batch_size=1, verbose=0)
        prediction = int(np.argmax(proba, axis=-1)[0])
        print("Actual class:{0} Predict class:{1}".format(np.argmax(train_y[img_idx], 0), prediction))

        if show_proba:
            print(proba)

    # Demo: show and label the first three training images, with probabilities.
    for i in range(3):
        predict_label(i, show_proba=True)

  • 相关阅读:
    75. Sort Colors
    101. Symmetric Tree
    121. Best Time to Buy and Sell Stock
    136. Single Number
    104. Maximum Depth of Binary Tree
    70. Climbing Stairs
    64. Minimum Path Sum
    62. Unique Paths
    css知识点3
    css知识点2
  • 原文地址:https://www.cnblogs.com/tszr/p/10089250.html
Copyright © 2011-2022 走看看