zoukankan      html  css  js  c++  java
  • 吴裕雄 python神经网络(8)

    # -*- coding=utf-8 -*-
    import numpy as np
    import keras
    from keras.models import Sequential
    from keras.layers import Dense,Flatten,Dropout
    from keras.optimizers import Adadelta
    from keras.datasets import cifar10
    from keras import applications

    import matplotlib.pyplot as plt
    %matplotlib inline

    # Convolutional base of VGG19 (classifier head dropped via include_top=False)
    # with ImageNet-pretrained weights; print its layer-by-layer summary.
    vgg_model = applications.VGG19(weights='imagenet', include_top=False)
    vgg_model.summary()

    # Load CIFAR-10 train/test images and their integer class labels.
    (train_x, train_y), (test_x, test_y) = cifar10.load_data()
    print(train_x.shape, train_y.shape, test_x.shape, test_y.shape)

    # One-hot encode the labels to match the categorical_crossentropy loss used later.
    n_classes = 10
    train_y = keras.utils.to_categorical(train_y, n_classes)
    test_y = keras.utils.to_categorical(test_y, n_classes)

    # The ImageNet weights expect inputs passed through VGG's own preprocess_input;
    # feeding raw pixels yields off-distribution features. astype() makes a float
    # copy, so train_x / test_x stay unmodified (train_x is displayed later by
    # predict_label). NOTE(review): the float32 copy of the training set is large
    # (~600 MB); batch this if memory is tight.
    preprocess = applications.vgg19.preprocess_input
    bottleneck_feature_train = vgg_model.predict(
        preprocess(train_x.astype('float32')), verbose=1)
    bottleneck_feature_test = vgg_model.predict(
        preprocess(test_x.astype('float32')), verbose=1)

    print(bottleneck_feature_train.shape, bottleneck_feature_test.shape)

    # Small fully connected classifier trained on the frozen VGG19 bottleneck features.
    # The input shape is taken from the extracted features themselves, so the model is
    # built immediately and stays correct if the backbone or image size changes.
    my_model = Sequential()
    my_model.add(Flatten(input_shape=bottleneck_feature_train.shape[1:]))
    my_model.add(Dense(512, activation='relu'))
    my_model.add(Dropout(0.5))
    my_model.add(Dense(256, activation='relu'))
    my_model.add(Dropout(0.5))
    my_model.add(Dense(n_classes, activation='softmax'))
    # NOTE(review): Adadelta's default learning rate differs between standalone Keras
    # and tf.keras; confirm which is installed if training converges very slowly.
    my_model.compile(optimizer=Adadelta(), loss="categorical_crossentropy",
                     metrics=['accuracy'])
    my_model.fit(bottleneck_feature_train, train_y, batch_size=128, epochs=50, verbose=1)

    # evaluate() returns [loss, accuracy], following compile()'s loss + metrics order.
    evaluation = my_model.evaluate(bottleneck_feature_test, test_y, batch_size=128, verbose=0)
    print("loss:", evaluation[0], "accuracy:", evaluation[1])

    def predict_label(img_idx, show_proba=True):
        """Display training image ``img_idx`` and print its actual vs. predicted class.

        Args:
            img_idx: index into train_x, train_y and bottleneck_feature_train.
            show_proba: when True, also print the classifier's softmax output.
        """
        plt.imshow(train_x[img_idx], aspect='auto')
        plt.title("Image to be labeled")
        plt.show()

        # Re-add a batch axis of size 1; `...` keeps this valid for features of any rank.
        img_4D = bottleneck_feature_train[img_idx][np.newaxis, ...]
        # Sequential.predict_classes / predict_proba were removed in TF 2.6. A single
        # predict() gives the softmax output; argmax over the last axis is the class id.
        proba = my_model.predict(img_4D, batch_size=1, verbose=0)
        prediction = np.argmax(proba, axis=-1)
        print("Actual class:{0} Predict class:{1}".format(np.argmax(train_y[img_idx], 0), prediction))

        if show_proba:
            print(proba)

    # Demo: label the first three training images, also printing class probabilities.
    for i in range(3):
        predict_label(i, show_proba=True)

  • 相关阅读:
    产生一个int数组,长度为100,并向其中随机插入1-100,并且不能重复。
    it人必进的几大网站
    可写可选dropdownlist(只测试过ie)
    Datatable转换为Json 的方法
    ref 和out的区别
    数据库事务
    Webservice 的安全策略
    【转】Zookeeper解析、安装、配置
    【转】activemq的几种基本通信方式总结
    【转】Java小应用:Eclipse中建立自己的类库,给不同的工程使用
  • 原文地址:https://www.cnblogs.com/tszr/p/10089250.html
Copyright © 2011-2022 走看看