zoukankan      html  css  js  c++  java
  • Tensorflow2.0语法

    转自 https://segmentfault.com/a/1190000021181739

    前言

    keras接口大都实现了 _call_ 方法。
    母类 _call_ 调用了 call()。
    因此下面说的几乎所有模型/网络层 都可以在定义后,直接像函数一样调用。
    eg:

    模型对象(参数) 
    网络层对象(参数)

    我们还可以实现继承模板

    导入

    from tensorflow import keras

    metrics (统计平均)

    里面有各种度量值的接口
    如:二分类、多分类交叉熵损失容器,MSE、MAE的损失值容器, Accuracy精确率容器等。
    下面以Accuracy伪码为例:

    acc_meter = keras.metrics.Accuracy() # 建立一个容器
    for _ in epoches:
        for _ in batches:
            y = ...
            y_predict = ...
            acc_meter.update_state(y, y_predict) # 每次扔进去数据,容器都会自动计算accuracy,并储存
        
            if times % 100 == 0: # 一百次一输出, 设置一个阈值/阀门
                print(acc_meter.result().numpy())   # 取出容器内所有储存的数据的,均值准确率
        acc_meter。reset_states()    # 容器缓存清空, 下一epoch从头计数。

    激活函数+损失函数+优化器

    导入方式:

    keras.activations.relu()    # 激活函数:以relu为例,还有很多
    keras.losses.categorical_crossentropy() # 损失函数:以交叉熵为例,还有很多
    keras.optimizers.SGD()      # 优化器:以随机梯度下降优化器为例
    keras.callbacks.EarlyStopping()  # 回调函数: 以‘按指定条件提前暂停训练’回调为例

    Sequential(继承自Model)属于模型

    模型定义方式

    定义方式1:

    model = keras.models.Sequential( [首层网络,第二层网络。。。] )

    定义方式1:

    model = keras.models.Sequential()
    model.add(首层网络)
    model.add(第二层网络)

    模型相关回调配置

    logdir = 'callbacks'
    if not os.path.exists(logdir):
        os.mkdir(logdir)
    save_model_file = os.path.join(logdir, 'mymodel.h5')
    
    callbacks = [
        keras.callbacks.TensorBoard(logdir),    # 写入tensorboard
        keras.callbacks.ModelCheckpoint(output_model_file, save_best_only=True),  # 模型保存
        keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3)  # 按条件终止模型训练
        # 验证集,每次都会提升,如果提升不动了,提升小于这个min_delta阈值,则会耐心等待5次。
        # 5次过后,要是还提升这么点。就提前结束。
    ]
    # 代码写在这里,如何传递调用, 下面 “模型相关量度配置” 会提到
    

    模型相关量度配置:((损失,优化器,准确率等)

    说明,下面的各种量度属性,可通过字符串方式,也可通过上面讲的导入实例化对象方式。

    model.compile(
        loss="sparse_categorical_crossentropy",    # 损失函数,这是字符串方式
        optimizer= keras.optimizers.SGD()          # 这是实例化对象的方式,这种方式可以传参
        metrics=['accuracy']  # 这项会在fit()时打印出来
    )  # compile() 操作,没有真正的训练。
    model.fit(
        x,y,
        epochs=10,                              # 反复训练 10 轮
        validation_data = (x_valid,y_valid),    # 把划分好的验证集放进来(fit时打印loss和val)
        validation_freq = 5,                    # 训练5次,验证一次。  可不传,默认为1。
        callbacks=callbacks,                    # 指定回调函数, 请衔接上面‘模型相关回调配置’
        
    )   # fit()才是真正的训练 

    模型 验证&测试

    一般我们会把 数据先分成三部分(如果用相同的数据,起不到测试和验证效果,参考考试作弊思想):

    1. 训练集: (大批量,主体)
    2. 测试集: (模型所有训练结束后, 才用到)
    3. 验证集: (训练的过程种就用到)

    说明1:(如何分离?)

    1. 它们的分离是需要(x,y)组合在一起的,如果手动实现,需要随机打散、zip等操作。
    2. 但我们可以通过 scikit-learn库,的 train_test_split() 方法来实现 (2次分隔)
    3. 可以使用 tf.split()来手动实现

    具体分离案例:参考上一篇文章: https://segmentfault.com/a/11...

    说明2:(为什么我们有了测试集,还需要验证集?)

    1. 测试集是用来在最终,模型训练成型后(参数固定),进行测试,并且返回的是预测的结果值!!!!
    2. 验证集是伴随着模型训练过程中而验证)

    代码如下:

    loss, accuracy = model.evaluate( (x_test, y_test) ) # 度量, 注意,返回的是精度指标等
    target = model.predict( (x_test, y_test) )          # 测试, 注意,返回的是 预测的结果!

    可用参数

    model.trainable_variables    # 返回模型中所有可训练的变量
    # 使用场景: 就像我们之前说过的 gradient 中用到的 zip(求导结果, model.trainable_variables)

    自定义Model

    Model相当于母版, 你继承了它,并实现对应方法,同样也能简便实现模型的定义。

    自定义Layer

    同Model, Layer也相当于母版, 你继承了它,并实现对应方法,同样也能简便实现网络层的定义。

    模型保存与加载

    方法1:之前callback说的

    方法2:只保存weight(模型不完全一致)

    保存:

    model = keras.Sequential([...])
    ...
    model.fit()
    model.save_weights('weights.ckpt')

    加载:

    假如在另一个文件中。(当然要把保存的权重要复制到本地目录)
    model = keras.Sequential([...])    # 此模型构建必须和保存时候定义结构的一模一样的!
    model.load_weights('weights.ckpt')
    model.evaluate(...)
    model.predict(...)
    

    方法3:保存整个模型(模型完全一致)

    保存:

    model = keras.Sequential([...])
    ...
    model.fit()
    model.save('model.h5')    # 注意 这里变了,是 save

    加载:(直接加载即可,不需要重新复原建模过程)

    假如在另一个文件中。(当然要把保存的模型要复制到本地目录)
    model = keras.models.load_model('model.h5')  # load_model是在 keras.models下
    model.evaluate(...)
    model.predict(...)

    方法4:导出可供其他语言使用(工业化)

    保存: (使用tf.saved_model模块)

    model = keras.Sequential([...])
    ...
    model.fit()
    tf.saved_model.save(model, '目录')

    加载:(使用tf.saved_model模块)

    model = tf.saved_model.load('目录')
  • 相关阅读:
    爬虫之爬取网贷之家在档P2P平台基本数据并存入数据库
    Python抓取第一网贷中国网贷理财每日收益率指数
    div左右布局
    IIS7.0+SqlServer2012,进行.net网站发布的安装全过程
    SpringMVC+Mybatis+Mysql实战项目学习环境搭建
    文本框字符长度动态统计
    html里面自定义弹出窗口
    windows下取linux系统里面的文件
    网页中的电话号码实现一键直呼
    测试
  • 原文地址:https://www.cnblogs.com/whw1314/p/12121908.html
Copyright © 2011-2022 走看看