zoukankan      html  css  js  c++  java
  • keras 八股文

    六步法

    1 import
    2 train,test
    3 model = Sequential()
    4 model.compile()
    5 model.fit()
    6 model.evaluate()

    Sequential()方法是一个容器,描述了神经网络的网络结构,在Sequential()的输入参数中描述从输入层到输出层的网络结构

    model = tf.keras.models.Sequential([网络结构])  #描述各层网络

    网络结构举例:

    拉直层:tf.keras.layers.Flatten() #拉直层可以变换张量的尺寸,把输入特征拉直为一维数组,是不含计算参数的层

    全连接层:tf.keras.layers.Dense(神经元个数,

                                                          activation = "激活函数“,

                                                          kernel_regularizer = "正则化方式)

    其中:activation可选 relu 、softmax、 sigmoid、 tanh等

               kernel_regularizer可选 tf.keras.regularizers.l1() 、tf.keras.regularizers.l2()

    卷积层:tf.keras.layers.Conv2D(filter = 卷积核个数,

                                                       kernel_size = 卷积核尺寸,

                                                       strides = 卷积步长,

                                                       padding = ”valid“ or "same")

    LSTM层:tf.keras.layers.LSTM()

    #导入模块
    import tensorflow as tf
    from sklearn import datasets
    import numpy as np
    #第二步,加载数据集
    x_train = datasets.load_iris().data
    y_train = datasets.load_iris().target
    
    np.random.seed(116)#设置随机种子,每次结果都一样方便对照
    np.random.shuffle(x_train)#使用shuffle()方法,让x_train乱序
    np.random.seed(116)#设置随机种子,每次结果一样,方便对照
    np.random.shuffle(y_train)#使用shuffle()方法,让输入y_train乱序
    tf.random.set_seed(116)
    #第三步,models.Sequential()
    model = tf.keras.models.Sequential()#使用models.Sequential()搭建神经网络
    tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
    #第四步,model.compile()编译
    model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),#使用SGD优化器,学习率为0.1
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),#配置损失函数
                  metrics=['sparse_categorical_accuracy'])#标注网络评价指标
    #第五步,训练
    model.fit(x_train, y_train, #告知训练集的输入以及标签
              batch_size=32, #每一批的batch的大小为32
              epochs=500, #迭代次数为500
              validation_split=0.2,#从训练集中选20%作为测试集
              validation_freq=20#没迭代20次训练集,要在测试集中验证一次准确率
              )
    #第六步,打印网络结果和参数
    model.summary()
  • 相关阅读:
    Twitter Storm安装配置(Ubuntu系统)单机版
    Ubuntu下安装配置JDK1.7
    JS性能优化
    JavaScript禁用页面刷新
    pomelo获取客户端IP
    MySQL数据库工具类之——DataTable批量加入MySQL数据库(Net版)
    MySQL5.6忘记root密码(win平台)
    清空文件下的SVN控制文件
    Windows平台搭建NodeJs开发环境以及HelloWorld展示—图解
    Unity3D默认的快捷键
  • 原文地址:https://www.cnblogs.com/hsy1941/p/13827712.html
Copyright © 2011-2022 走看看