zoukankan      html  css  js  c++  java
  • reshape、shuffle、save_weights

    #-*- coding: utf-8 -*-
    
    import pandas as pd
    from random import shuffle
    import matplotlib.pyplot as plt #导入Matplotlib
    
    datafile = '../data/model.xls'
    data = pd.read_excel(datafile)
    data = data.as_matrix()
    shuffle(data)
    
    p = 0.8 #设置训练数据比例
    train = data[:int(len(data)*p),:]
    test = data[int(len(data)*p):,:]
    
    #构建LM神经网络模型
    from keras.models import Sequential #导入神经网络初始化函数
    from keras.layers.core import Dense, Activation #导入神经网络层函数、激活函数
    
    netfile = '../tmp/net.model' #构建的神经网络模型存储路径
    
    net = Sequential() #建立神经网络
    net.add(Dense(input_dim = 3, output_dim = 10)) #添加输入层(3节点)到隐藏层(10节点)的连接
    net.add(Activation('relu')) #隐藏层使用relu激活函数
    net.add(Dense(input_dim = 10, output_dim = 1)) #添加隐藏层(10节点)到输出层(1节点)的连接
    net.add(Activation('sigmoid')) #输出层使用sigmoid激活函数
    net.compile(loss = 'binary_crossentropy', optimizer = 'adam', metrics=['accuracy']) #编译模型,使用adam方法求解
    
    net.fit(train[:,:3], train[:,3], nb_epoch=50, batch_size=1) #训练模型,循环1000次
    net.save_weights(netfile) #保存模型
    #print(net.predict_classes(train[:,:3]))
    # [[1]
    #  [1]
    #  [1]
    #  [1]
    #  [1]
    #  [1]
    #  [1]
    #  [1]
    #  [0]
    #  [1]
    #  [0]
    #  [1]
    #  [1]
    #  [0]
    predict_result = net.predict_classes(train[:,:3]).reshape(len(train)) #预测结果变形
    #print(predict_result)
    #[1 1 1 1 1 1 1 1 0 1 0 1 1 0 0]
    '''这里要提醒的是,keras用predict给出预测概率,predict_classes才是给出预测类别,而且两者的预测结果都是n x 1维数组,而不是通常的 1 x n'''
    
    # from cm_plot import * #导入自行编写的混淆矩阵可视化函数
    # cm_plot(train[:,3], predict_result).show() #显示混淆矩阵可视化结果
    
    from sklearn.metrics import roc_curve #导入ROC曲线函数
    
    predict_result = net.predict(test[:,:3]).reshape(len(test))
    fpr, tpr, thresholds = roc_curve(test[:,3], predict_result, pos_label=1)
    plt.plot(fpr, tpr, linewidth=2, label = 'ROC of LM') #作出ROC曲线
    plt.xlabel('False Positive Rate') #坐标轴标签
    plt.ylabel('True Positive Rate') #坐标轴标签
    plt.ylim(0,1.05) #边界范围
    plt.xlim(0,1.05) #边界范围
    plt.legend(loc=4) #图例
    plt.show() #显示作图结果
  • 相关阅读:
    Windows下安装并设置Redis
    Android平台相机接口的应用
    Struts(十五):主题
    Struts(十四):通用标签-form表单
    Struts(十三):通用标签
    Struts(十二):异常处理:exception-mapping元素
    Struts(十一):OGNL表达式(二)
    Struts(十):OGNL表达式(一)
    A*算法
    Centos:如何查找安装的jdk的目录
  • 原文地址:https://www.cnblogs.com/ggzhangxiaochao/p/9115295.html
Copyright © 2011-2022 走看看