zoukankan      html  css  js  c++  java
  • Keras人工神经网络多分类(SGD)

    import numpy as np
    import pandas as pd
    from keras.models import Sequential
    from keras.layers import Dense, Dropout
    from keras.wrappers.scikit_learn import KerasClassifier
    from keras.utils import np_utils
    from sklearn.model_selection import train_test_split, KFold, cross_val_score
    from sklearn.preprocessing import LabelEncoder
    from keras.optimizers import SGD
    from keras.layers import LSTM
    
    # load dataset
    # Columns 0-18 are the features; column 19 is the class label, which is used
    # as an integer column index when one-hot encoding below.
    dataframe = pd.read_csv("./data/iris1.csv", header=None)
    dataset = dataframe.values
    X = dataset[:, 0:19].astype(float)
    dummy_y1 = dataset[:, 19]

    # One-hot encode the labels into n = 6 columns: a non-zero label k goes to
    # column k - 1, and label 0 goes to the last column (index 5).
    # The row count used to be hard-coded to 1682; derive it from the data so a
    # CSV with a different number of rows neither crashes nor leaves rows
    # unencoded.
    m, n = len(dataset), 6
    # If any CSV column is non-integer, `.values` upcasts the whole array (labels
    # included) to float64, and numpy refuses float indices -- so cast first.
    labels = dummy_y1.astype(int)
    columns = np.where(labels != 0, labels - 1, n - 1)
    dum_imax = np.zeros((m, n))
    dum_imax[np.arange(m), columns] = 1
    dummy_y = dum_imax
    print(dummy_y)
    print(type(dummy_y[0][0]))
    
    def baseline_model(input_dim=19, hidden_units=50, n_classes=6):
        """Build and compile the fully connected multi-class classifier.

        Architecture: Dense(hidden_units, relu) -> Dropout(0.4) ->
        Dense(n_classes, softmax). It is compiled with categorical
        cross-entropy and Nesterov-momentum SGD, and reports accuracy.

        The defaults reproduce the original hard-coded network (19 input
        features, 50 hidden units, 6 classes). Per the Keras scikit-learn
        wrapper docs, KerasClassifier passes any constructor keyword whose
        name matches a build_fn argument. These can therefore also be set via
        ``KerasClassifier(build_fn=baseline_model, n_classes=...)``.

        Returns:
            A compiled keras ``Sequential`` model.
        """
        model = Sequential()
        # `units` is the Keras 2 name for the deprecated Keras 1 `output_dim`.
        model.add(Dense(units=hidden_units, input_dim=input_dim, activation='relu'))
        model.add(Dropout(0.4))
        # Only the first layer needs an input size; Keras infers the rest, so
        # the former `input_dim=50` here was a no-op and has been dropped.
        model.add(Dense(units=n_classes, activation='softmax'))
        sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
        # Multi-class problem with one-hot targets and a softmax output, hence
        # categorical_crossentropy (not binary_crossentropy), optimized with SGD.
        model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
        return model
    # Wrap the Keras model as a scikit-learn estimator. These settings apply
    # whenever the estimator (or a clone, as in cross_val_score) is fitted.
    # Keras 2's fit() takes `epochs`. The old `nb_epoch` name is not a fit()
    # argument there, so the wrapper dropped it, and cross-validation trained
    # for the default single epoch.
    estimator = KerasClassifier(build_fn=baseline_model, epochs=40, batch_size=256)
    print(estimator)
    
    # Split the data into a training set and a test set. Holding out 20% of the
    # rows with a fixed random_state makes the split identical on every run.
    X_train, X_test, Y_train, Y_test = train_test_split(X, dummy_y, test_size=0.2, random_state=0)
    print(len(X_train[0]))
    print(len(Y_train[0]))
    # Train for 100 epochs. The keyword here overrides the constructor's epoch
    # count for this fit call only.
    estimator.fit(X_train, Y_train, epochs=100)

    # make predictions
    # NOTE(review): with one-hot targets the wrapper presumably returns column
    # indices of dummy_y. Index i < 5 corresponds to label i + 1, and index 5
    # to label 0 (see the encoding step). Confirm before mapping back.
    print(X_test)
    pred = estimator.predict(X_test)
    print(pred)
    
    # k-fold cross-validate
    # KFold parameters used here:
    #   n_splits     -- K, the number of folds (minimum 2)
    #   shuffle      -- shuffle the rows once before splitting them into folds
    #   random_state -- seed for that shuffle; fixed so the folds are reproducible
    seed = 42
    kfold = KFold(n_splits=5, shuffle=True, random_state=seed)

    # Evaluate the estimator with 5-fold cross-validation. cross_val_score fits
    # a clone of the estimator on each training split and scores it on the
    # held-out fold with the estimator's own score() method.
    results = cross_val_score(estimator, X, dummy_y, cv=kfold)
    # Report the mean and standard deviation of the 5 fold scores as percentages.
    print("baseline:%.2f%%(%.2f%%)" % (results.mean() * 100, results.std() * 100))
  • 相关阅读:
    [Angular 9] Built-in template syntax $any
    [Angular 9] Improved Dependency Injection with the new providedIn scopes 'any' and 'platform'
    [Angular 9 Unit testing] Stronger typing for dependency injection in tests
    [Angular] Preserve the current route’s query parameters when navigating with the Angular Router
    [Angular] Do relative routing inside component
    [Typescript] Make your optional fields required in TypeScript
    [Typescript] Exclude Properties from a Type in TypeScript
    [Javascript] Hide Properties from Showing Up in "for ... in" Loops in JavaScript
    [Debug] Set and remove DOM breakpoints
    【职业素养】4种让你显得没教养的做法
  • 原文地址:https://www.cnblogs.com/smuxiaolei/p/8655738.html
Copyright © 2011-2022 走看看