zoukankan      html  css  js  c++  java
  • 使用keras构建简单的网络分类鸢尾花

    TensorFlow==1.8.0

    # -*- coding: utf-8 -*-
    from warnings import simplefilter
    simplefilter(action='ignore', category=FutureWarning)
    
    import numpy as np
    import pandas as pd
    from keras.models import Sequential     # 链式构建模型
    from keras.layers import Dense  # 全连接层
    from keras.wrappers.scikit_learn import KerasClassifier
    from keras.utils import np_utils
    from sklearn.model_selection import cross_val_score     # 交叉验证
    from sklearn.model_selection import KFold   # 数据分割,1个作为test,k-1个作为train
    from sklearn.preprocessing import LabelEncoder
    from keras.models import model_from_json   # 模型保存
    
    
    
    # Reproducibility: seed NumPy's global RNG (the KFold split further down is
    # also given this seed explicitly). NOTE(review): with a TensorFlow backend
    # this alone may not make Keras weight initialisation repeatable — confirm.
    seed = 13
    np.random.seed(seed)

    # Load the data: columns 1-4 are taken as the four numeric features and
    # column 5 as the class name; column 0 is skipped (presumably a row id —
    # verify against iris.csv).
    df = pd.read_csv('iris.csv')
    X, Y = df.values[:, 1:5].astype(float), df.values[:, 5]

    # Class-name strings -> integer ids -> one-hot rows, the target format
    # used for training below.
    encoder = LabelEncoder()
    encoder.fit(Y)
    Y_encoder = encoder.transform(Y)
    Y_onehot = np_utils.to_categorical(Y_encoder)
    
    # input=4,hidden=7,output=3
    def baseline_model(input_dim=4, hidden_units=7, num_classes=3):
        """Build and compile a one-hidden-layer classifier.

        Network: ``input_dim -> hidden_units (tanh) -> num_classes (softmax)``.
        The defaults reproduce the iris setup (4 features, 7 hidden units,
        3 classes), so ``baseline_model()`` can still be used as the
        ``build_fn`` of ``KerasClassifier`` with no arguments.

        Args:
            input_dim: number of input features.
            hidden_units: width of the hidden tanh layer.
            num_classes: number of softmax output units (classes).

        Returns:
            A compiled keras ``Sequential`` model.
        """
        model = Sequential()
        model.add(Dense(hidden_units, input_dim=input_dim, activation='tanh'))
        model.add(Dense(num_classes, activation='softmax'))
        # Categorical cross-entropy is the loss that matches a softmax output
        # trained on one-hot targets. Mean squared error on a softmax gives
        # vanishing gradients when a prediction is confidently wrong.
        model.compile(loss='categorical_crossentropy', optimizer='sgd',
                      metrics=['accuracy'])
        return model
    
    # Wrap the model factory so scikit-learn utilities can fit and score it.
    # epochs / batch_size / verbose are training settings handed to the
    # wrapper (presumably forwarded to model.fit — see the wrapper docs).
    estimator = KerasClassifier(build_fn=baseline_model, epochs=20, batch_size=1, verbose=1)

    # Evaluate with 10-fold cross-validation; the shuffle is seeded so the
    # fold assignment is repeatable.
    kfold = KFold(n_splits=10, shuffle=True, random_state=seed)
    result = cross_val_score(estimator, X, Y_onehot, cv=kfold)
    # `result` holds one score per fold; report their mean and spread.
    print("Accuracy of cross validation, mean %.2f, std %.2f" % (result.mean(), result.std()))
    
    # Train on all samples, then persist the result in two files:
    # architecture -> model.json, weights -> model.h5.
    # NOTE(review): cross_val_score above presumably fit clones of the
    # estimator, so this fit is what gives `estimator.model` trained weights.
    estimator.fit(X, Y_onehot)
    model_json = estimator.model.to_json()
    with open("model.json", mode="w") as json_file:
        print(model_json, end="", file=json_file)
    estimator.model.save_weights("model.h5")
    print("save model to disk")
    
    # Load the model back from disk: architecture from JSON, then weights.
    # A context manager guarantees the file is closed even if read() fails.
    with open("model.json", "r") as json_file:
        loaded_model_json = json_file.read()

    loaded_model = model_from_json(loaded_model_json)
    loaded_model.load_weights("model.h5")
    print("loaded model from disk")

    # Softmax class probabilities, one row per sample in X.
    predicted = loaded_model.predict(X)
    print("predicted probability:" + str(predicted))

    # Predicted class index = arg-max over the softmax outputs. For a
    # multi-unit output this is what Sequential.predict_classes returned, but
    # it reuses the probabilities above instead of running inference again and
    # does not rely on predict_classes, which newer tf.keras releases removed.
    predicted_label = np.argmax(predicted, axis=-1)
    print("predicted label:" + str(predicted_label))
    
  • 相关阅读:
    poj 3068 Bridge Across Islands
    XidianOJ 1086 Flappy v8
    XidianOJ 1036 分配宝藏
    XidianOJ 1090 爬树的V8
    XidianOJ 1088 AK后的V8
    XidianOJ 1062 Black King Bar
    XidianOJ 1091 看Dota视频的V8
    XidianOJ 1098 突击数论前的xry111
    XidianOJ 1019 自然数的秘密
    XidianOJ 1109 Too Naive
  • 原文地址:https://www.cnblogs.com/long5683/p/12885760.html
Copyright © 2011-2022 走看看