zoukankan      html  css  js  c++  java
  • Tensorflow2.0学习(6)---Keras

    来自书籍:TensorFlow深度学习

    一、常见功能模块

    • 经典数据集加载函数
    • 网络层类
    • 模型容器
    • 损失函数类
    • 优化器类
    • 经典模型类

    1、网络层类:

    两种实现方式:张量方式(tf.nn)和层方式(tf.keras.layers)

    如实现Softmax层:

      • tf.nn.softmax函数实现;
      • layers.Softmax(axis)类搭建Softmax网络层;

     2、网络容器

    Keras提供的网络容器Sequential将多个网络层封装成一个大网络模型。

    2层的全连接层加上单独的激活函数,并用Sequential封装成一个网络。

    注释:

    build函数指定输入大小,即可自动创建所有层的内部张量。

    summary函数打印出网络结构和参数量。

     3、模型装配、训练、预测

    (1)keras.Model类和keras.layers.Layer类

    • Layer类是网络层的母类,定义了网络层的一些常见功能:添加权值、管理权值列表等。
    • Model类是网络的母类,可添加了保存模型、加载模型、训练与测试模型等功能。

    (2)简述一下模型创建训练过程:

    • 创建网络
    • 循环迭代数据集多个Epoch,每次按批产生训练数据
    • 前向传播
    • 通过损失函数计算误差值
    • 反向传播自动计算梯度
    • 更新网络参数

    (3)模型装配:

    compile函数:指定网络使用的优化器对象、损失函数类型,评价指标等设定

    (4)模型训练验证:

    fit函数:模型训练验证,通过fit()函数送入待训练的数据集和验证用的数据集

     

     (5)模型预测:

    Model.predict(x):完成模型预测。

     4、模型保存与加载

    三种保存与加载方式:张量方式、网络方式、SavedModel方式

    (1)张量方式

    Model.save_weights(path):将当前的网络参数保存到path文件上。

    Model.load_weights(path):加载网络参数,但需提前创建网络对象。

     

     (2)网络方式

    Model.save(path):将模型的结构以及模型的参数保存到path文件上。

    keras.models.load_model(path):加载模型,不需要提前创建网络对象。

     (3)SavedModel方式

    tf.saved_model.save(network, path):将network模型以SavedModel方式保存到path中

    tf.saved_model.load函数加载模型

     5、自定义网络

    • 自定义网络层类时,需要继承自layers.Layer基类;【即单层网络层的结构】
    • 自定义的网络类时,需要继承自keras.Model基类。【即整体模型结构】

    自定义网络层:

     自定义网络:

     6、经典模型类

    •  ResNet

     7、测量工具

    keras.metrics模块

    4个步骤:

    • 新建测量器
    • 写入数据
    • 读取统计数据
    • 清零测量器

    (1)新建测量器

    常用的测量器:Mean类、Accuracy类、CosineSimilarity类等。

     

     (2)写入数据

    update_state函数

    在每个step结束时采集一次loss值,以下代码放置在每个Batch运算结束后,测量器会自动根据采用的数据来统计平均值。

     

     (3)读取统计信息

    result函数

     (4)清楚状态

    reset_states()函数

     

     8、可视化

    TensorBoard

    tf.summary.create_file_writer:创建监控对象类实例,并指定监控数据的写入目录。

    tf.summary.scalar:记录监控数据,并指定时间戳step参数。

    tf.summary.histogram:查看张量数据的直方图分布

    tf.summary.text:打印文本信息等

    对于图片类型的数据:

     浏览器查看:

    cmd执行 tensorboard --logdir path:指定web后端监控的文件目录path。

    二、过拟合

    1、正则化

    L1正则化:

     L2正则化:

     2、Dropout

     3、数据增强

     (1)旋转

     (2)翻转

    (3)裁剪

  • 相关阅读:
    ES6, CommonJS, AMD, CMD,UMD模块化规范介绍及使用
    前端项目开发框架选型需考虑的4个方面
    初识webSocket及其使用
    动态组件 —— 2种方式实现动态组件的切换
    mac下anaconda安装新包
    新版docker设置国内镜像
    记一次解决Original error: UiAutomator quit before it successfully launched
    linux clion cmakelisits undefined reference 未定义引用
    苹果设备插入PC不能识别问题解决办法
    用Cucumber理解BDD行为驱动开发
  • 原文地址:https://www.cnblogs.com/Lee-yl/p/12573609.html
Copyright © 2011-2022 走看看