  • Using BatchNormalization

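    # import the model class, layers and optimizer (tf.keras, TensorFlow 2.x)
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Input, Dense, BatchNormalization, Activation, Dropout
    from tensorflow.keras.optimizers import SGD
    from tensorflow.keras.optimizers.schedules import InverseTimeDecay
    
    # instantiate model
    model = Sequential()
    
    # the input: 14 features per sample
    model.add(Input(shape=(14,)))
    
    # we can think of this chunk as the first hidden layer:
    # Dense -> BatchNormalization -> activation -> Dropout
    # ('random_uniform' is the current name of the old 'uniform' initializer)
    model.add(Dense(64, kernel_initializer='random_uniform'))
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    model.add(Dropout(0.5))
    
    # we can think of this chunk as the second hidden layer
    model.add(Dense(64, kernel_initializer='random_uniform'))
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    model.add(Dropout(0.5))
    
    # we can think of this chunk as the output layer (2 classes, softmax)
    model.add(Dense(2, kernel_initializer='random_uniform'))
    model.add(BatchNormalization())
    model.add(Activation('softmax'))
    
    # setting up the optimization of our weights: SGD with Nesterov momentum.
    # The old decay=1e-6 argument (lr / (1 + decay * iterations)) is no longer
    # accepted by current optimizers, so it is expressed as a learning-rate schedule.
    lr_schedule = InverseTimeDecay(initial_learning_rate=0.1, decay_steps=1, decay_rate=1e-6)
    sgd = SGD(learning_rate=lr_schedule, momentum=0.9, nesterov=True)
    # accuracy is requested via metrics= (fit() no longer accepts show_accuracy)
    model.compile(loss='binary_crossentropy', optimizer=sgd, metrics=['accuracy'])
    
    # running the fitting; X_train (shape (n_samples, 14)) and y_train
    # (one-hot labels, shape (n_samples, 2)) are assumed to be loaded already
    model.fit(X_train, y_train, epochs=20, batch_size=16, validation_split=0.2, verbose=2)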
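The code above places BatchNormalization between each Dense layer and its activation but does not show what the layer computes. Below is a minimal NumPy sketch of the forward pass for 2-D input of shape (batch, features), assuming the standard formulation: in training mode each feature is normalized with the mean and variance of the current mini-batch, then scaled by a learnable gamma and shifted by a learnable beta, while running averages of the batch statistics are kept for inference. The class name `BatchNormSketch` is hypothetical and this is not Keras' actual implementation; only the defaults `momentum=0.99` and `epsilon=1e-3` are taken from Keras' `BatchNormalization`, and details such as how the variance is estimated may differ from the library.

```python
import numpy as np


class BatchNormSketch:
    """Hypothetical, simplified batch normalization for (batch, features) input.

    A sketch only, not the Keras implementation; momentum and epsilon
    defaults mirror Keras' BatchNormalization.
    """

    def __init__(self, num_features, momentum=0.99, epsilon=1e-3):
        self.momentum = momentum
        self.epsilon = epsilon
        # learnable scale (gamma) and shift (beta), initialised to 1 and 0
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)
        # running statistics, used instead of batch statistics at inference time
        self.moving_mean = np.zeros(num_features)
        self.moving_var = np.ones(num_features)

    def __call__(self, x, training):
        if training:
            # normalise each feature with the statistics of the current mini-batch
            mean = x.mean(axis=0)
            var = x.var(axis=0)
            # exponential moving average of the batch statistics
            self.moving_mean = self.momentum * self.moving_mean + (1 - self.momentum) * mean
            self.moving_var = self.momentum * self.moving_var + (1 - self.momentum) * var
        else:
            # inference: use the accumulated running statistics, no updates
            mean, var = self.moving_mean, self.moving_var
        x_hat = (x - mean) / np.sqrt(var + self.epsilon)
        return self.gamma * x_hat + self.beta


# usage: a batch of 16 samples with 64 features, e.g. the output of Dense(64)
rng = np.random.default_rng(0)
batch = rng.normal(loc=5.0, scale=3.0, size=(16, 64))
bn = BatchNormSketch(64)
train_out = bn(batch, training=True)   # per-feature mean ~0, std ~1
infer_out = bn(batch, training=False)  # normalised with moving_mean / moving_var
```

Keras makes this switch itself: `fit()` runs the layer in training mode (batch statistics, moving averages updated), while `predict()` and `evaluate()` use the moving averages, so a prediction for one sample does not depend on the other samples in its batch.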
  • Original post: https://www.cnblogs.com/ggzhangxiaochao/p/9051343.html