zoukankan      html  css  js  c++  java
  • reshape、shuffle、save_weights

    #-*- coding: utf-8 -*-
    
    import pandas as pd
    from random import shuffle
    import matplotlib.pyplot as plt #导入Matplotlib
    
    datafile = '../data/model.xls'
    data = pd.read_excel(datafile)
    data = data.as_matrix()
    shuffle(data)
    
    p = 0.8 #设置训练数据比例
    train = data[:int(len(data)*p),:]
    test = data[int(len(data)*p):,:]
    
    #构建LM神经网络模型
    from keras.models import Sequential #导入神经网络初始化函数
    from keras.layers.core import Dense, Activation #导入神经网络层函数、激活函数
    
    netfile = '../tmp/net.model' #构建的神经网络模型存储路径
    
    net = Sequential() #建立神经网络
    net.add(Dense(input_dim = 3, output_dim = 10)) #添加输入层(3节点)到隐藏层(10节点)的连接
    net.add(Activation('relu')) #隐藏层使用relu激活函数
    net.add(Dense(input_dim = 10, output_dim = 1)) #添加隐藏层(10节点)到输出层(1节点)的连接
    net.add(Activation('sigmoid')) #输出层使用sigmoid激活函数
    net.compile(loss = 'binary_crossentropy', optimizer = 'adam', metrics=['accuracy']) #编译模型,使用adam方法求解
    
    net.fit(train[:,:3], train[:,3], nb_epoch=50, batch_size=1) #训练模型,循环1000次
    net.save_weights(netfile) #保存模型
    #print(net.predict_classes(train[:,:3]))
    # [[1]
    #  [1]
    #  [1]
    #  [1]
    #  [1]
    #  [1]
    #  [1]
    #  [1]
    #  [0]
    #  [1]
    #  [0]
    #  [1]
    #  [1]
    #  [0]
    predict_result = net.predict_classes(train[:,:3]).reshape(len(train)) #预测结果变形
    #print(predict_result)
    #[1 1 1 1 1 1 1 1 0 1 0 1 1 0 0]
    '''这里要提醒的是,keras用predict给出预测概率,predict_classes才是给出预测类别,而且两者的预测结果都是n x 1维数组,而不是通常的 1 x n'''
    
    # from cm_plot import * #导入自行编写的混淆矩阵可视化函数
    # cm_plot(train[:,3], predict_result).show() #显示混淆矩阵可视化结果
    
    from sklearn.metrics import roc_curve #导入ROC曲线函数
    
    predict_result = net.predict(test[:,:3]).reshape(len(test))
    fpr, tpr, thresholds = roc_curve(test[:,3], predict_result, pos_label=1)
    plt.plot(fpr, tpr, linewidth=2, label = 'ROC of LM') #作出ROC曲线
    plt.xlabel('False Positive Rate') #坐标轴标签
    plt.ylabel('True Positive Rate') #坐标轴标签
    plt.ylim(0,1.05) #边界范围
    plt.xlim(0,1.05) #边界范围
    plt.legend(loc=4) #图例
    plt.show() #显示作图结果
  • 相关阅读:
    软件定义网络实验4:Open vSwitch 实验——Mininet 中使用 OVS 命令(实验过程及结果记录)
    软件定义网络实验3:测量路径的损耗率 (实验过程及结果记录)
    第一次个人编程作业
    软件定义网络实验2:Mininet拓扑的命令脚本生成(实验过程及结果记录)
    软件定义网络实验1:Mininet源码安装和可视化拓扑工具(实验过程及结果记录)
    第一次博客作业
    第07组(69) 需求分析报告
    第七组(69)团队展示
    第三次作业
    结对编程作业
  • 原文地址:https://www.cnblogs.com/ggzhangxiaochao/p/9115295.html
Copyright © 2011-2022 走看看