zoukankan      html  css  js  c++  java
  • 吴裕雄 python神经网络(8)

    # -*- coding=utf-8 -*-
    import numpy as np
    import keras
    from keras.models import Sequential
    from keras.layers import Dense,Flatten,Dropout
    from keras.optimizers import Adadelta
    from keras.datasets import cifar10
    from keras import applications

    import matplotlib.pyplot as plt
    %matplotlib inline

    # Convolutional base of VGG19 with ImageNet weights. include_top=False
    # drops the dense classifier head, so the network outputs feature maps
    # that a separate classifier is trained on further down.
    vgg_model = applications.VGG19(weights='imagenet', include_top=False)
    vgg_model.summary()

    # Load the CIFAR-10 train/test split (images plus integer class labels).
    (train_x, train_y), (test_x, test_y) = cifar10.load_data()
    print(train_x.shape, train_y.shape, test_x.shape, test_y.shape)

    # One-hot encode both label sets for categorical_crossentropy.
    n_classes = 10
    train_y, test_y = [keras.utils.to_categorical(labels, n_classes)
                       for labels in (train_y, test_y)]

    # Extract "bottleneck" features by running the images through the VGG19
    # convolutional base.
    # The ImageNet weights expect inputs normalised by vgg19.preprocess_input
    # (RGB->BGR plus per-channel mean subtraction). Feeding raw 0-255 pixels
    # produces features that do not match what the pretrained filters learned.
    # astype() makes a float copy, so train_x/test_x stay untouched for
    # plotting in predict_label().
    from keras.applications.vgg19 import preprocess_input

    bottleneck_feature_train = vgg_model.predict(
        preprocess_input(train_x.astype('float32')), verbose=1)
    bottleneck_feature_test = vgg_model.predict(
        preprocess_input(test_x.astype('float32')), verbose=1)

    print(bottleneck_feature_train.shape, bottleneck_feature_test.shape)

    # Small fully connected classifier trained on top of the VGG19 features.
    # Giving Flatten the per-sample feature shape builds the model right away
    # (instead of deferring the build to the first fit() call).
    my_model = Sequential()
    my_model.add(Flatten(input_shape=bottleneck_feature_train.shape[1:]))
    my_model.add(Dense(512, activation='relu'))
    my_model.add(Dropout(0.5))
    my_model.add(Dense(256, activation='relu'))
    my_model.add(Dropout(0.5))
    my_model.add(Dense(n_classes, activation='softmax'))

    # Adadelta's learning rate is passed positionally because it is the first
    # argument in every Keras version (named `lr` in old releases,
    # `learning_rate` in new ones). Newer Keras defaults it to 0.001 instead of
    # 1.0, the value standalone Keras used and that matches the original
    # Adadelta formulation; pinning it keeps training behaviour stable.
    my_model.compile(optimizer=Adadelta(1.0),
                     loss="categorical_crossentropy",
                     metrics=['accuracy'])
    my_model.fit(bottleneck_feature_train, train_y,
                 batch_size=128, epochs=50, verbose=1)

    # Score the trained classifier on the held-out test features.
    evaluation = my_model.evaluate(
        bottleneck_feature_test, test_y, verbose=0, batch_size=128)
    # evaluate() yields the loss first, then the metrics listed in compile().
    print("loss:", evaluation[0], "accuracy:", evaluation[1])

    def predict_label(img_idx, show_proba=True):
        """Show training image *img_idx* and compare its true and predicted class.

        Reads the module-level ``train_x``, ``train_y``,
        ``bottleneck_feature_train`` and ``my_model``.

        Args:
            img_idx: index into the training set.
            show_proba: if true, also print the predicted class probabilities.
        """
        plt.imshow(train_x[img_idx], aspect='auto')
        plt.title("Image to be labeled")
        plt.show()

        # Add a leading batch axis of size 1: predict() works on batches.
        features = bottleneck_feature_train[img_idx][np.newaxis, ...]

        # Sequential.predict_classes / predict_proba were removed from Keras
        # (TensorFlow 2.6+). The output layer is softmax, so predict() already
        # returns class probabilities and argmax picks the predicted class.
        proba = my_model.predict(features, batch_size=1, verbose=0)
        prediction = int(np.argmax(proba, axis=-1)[0])
        print("Actual class:{0} Predict class:{1}".format(
            np.argmax(train_y[img_idx], 0), prediction))

        if show_proba:
            print(proba)

    # Demo: show and label the first three training images.
    for i in range(3):
        predict_label(i, show_proba=True)

  • 相关阅读:
    时空权衡之计数排序
    何时发生隐式类型转换
    常量指针与指针常量的区别
    虚函数有关面试题
    C++中数组定义及初始化
    InputStream类的available()方法
    fork()函数
    面向对象三大基本特性，五大基本原则
    SpringMVC工作原理
    java文件的上传
  • 原文地址:https://www.cnblogs.com/tszr/p/10089250.html
Copyright © 2011-2022 走看看