zoukankan      html  css  js  c++  java
  • Tensorflow2.0笔记18——函数用法介绍

    Tensorflow2.0笔记

    本博客为Tensorflow2.0学习笔记,感谢北京大学微电子学院曹建老师

    1.3 函数用法介绍

    ——tf.keras.models.Sequential()

    ​ Sequential 函数是一个容器,描述了神经网络的网络结构,在 Sequential函数的输入参数中描述从输入层到输出层的网络结构。

    ​ 例如:

    拉直层:tf.keras.layers.Flatten() 拉直层可以变换张量的尺寸,把输入特征拉直为一维数组,是不含计算参数的层。

    全连接层:tf.keras.layers.Dense( 神经元个数, activation=”激活函数”, kernel_regularizer=”正则化方式”) 其中: activation(字符串给出)可选 relu、softmax、sigmoid、tanh 等,kernel_regularizer 可选 tf.keras.regularizers.l1()、 tf.keras.regularizers.l2()

    卷积层:tf.keras.layers.Conv2D( filter = 卷积核个数, kernel_size = 卷积核尺寸, strides = 卷积步长, padding = “valid” or “same”)

    STM 层:tf.keras.layers.LSTM()。 本章只使用拉直层和全连接层,卷积层和循环神经网络层将在之后的章节介绍。

    ——Model.compile( optimizer = 优化器, loss = 损失函数, metrics = [“准确率”])

    Compile 用于配置神经网络的训练方法,告知训练时使用的优化器、损失函数和准确率评测标准。

    ​ 其中:

    ​ optimizer 可以是字符串形式给出的优化器名字,也可以是函数形式,使用函数形式可以设置学习率、动量和超参数。 可选项包括:

    ​ sgd’or tf.optimizers.SGD( lr=学习率, decay=学习率衰减率, momentum=动量参数)

    ​ ‘adagrad’or tf.keras.optimizers.Adagrad(lr=学习率, decay=学习率衰减率)

    ​ ‘adadelta’or tf.keras.optimizers.Adadelta(lr=学习率, decay=学习率衰减率)

    ​ ‘adam’or tf.keras.optimizers.Adam (lr=学习率, decay=学习率衰减率)

    Loss 可以是字符串形式给出的损失函数的名字,也可以是函数形式。可选项包括:

    ​ ‘mse’or tf.keras.losses.MeanSquaredError()

    ​ ‘sparse_categorical_crossentropy or tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

    ​ 损失函数常需要经过 softmax 等函数将输出转化为概率分布的形式。from_logits 则用来标注该损失函数是否需要转换为概率的形式,取 False 时表示转化为概率分布,取 True 时表示没有转化为概率分布,直接输出。

    Metrics 标注网络评测指标。可选项包括:

    ​ ‘accuracy’:y_和 y 都是数值,如 y_=[1] y=[1]。

    ​ ‘categorical_accuracy’:y_和 y 都是以独热码和概率分布表示。 如 y_=[0, 1, 0], y=[0.256, 0.695, 0.048] 。

    ‘sparse_ categorical_accuracy’:y_是以数值形式给出,y 是以独热码形式给出。 如 y_=[1],y=[0.256, 0.695, 0.048]。 
    

    ——model.fit(训练集的输入特征, 训练集的标签, batch_size, epochs,validation_data = (测试集的输入特征,测试集的标签),validataion_split = 从测试集划分多少比例给训练集, validation_freq = 测试的 epoch 间隔次数)

    ​ fit 函数用于执行训练过程。

    ——model.summary()

    summary 函数用于打印网络结构和参数统计。

    image-20210622202321988

    ​ 上图是 model.summary()对鸢尾花分类网络的网络结构和参数统计,对于一个输入为 4 输出为 3 的全连接网络,共有 15 个参数。

  • 相关阅读:
    poj1580
    poj1607
    poj1313
    poj1314
    c语言之extern和static
    C笔记(一)
    搭建Linux高可用性集群(第一天)
    利用回调函数实现泛型算法
    关于SQL server中的 identity
    SQL(一)
  • 原文地址:https://www.cnblogs.com/wind-and-sky/p/14920253.html
Copyright © 2011-2022 走看看