  • Writing a TensorFlow training loop by hand: an example

    Reference: Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow: Concepts, Tools, and Techniques to Build Intelligent Systems

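    import numpy as np
    import tensorflow as tf
    from tensorflow import keras

    # Assumed to exist already: X_train_scaled (standardized features) and
    # y_train with shape (n_samples, 1), e.g. a regression dataset such as
    # California housing. A 1-D y_train would broadcast wrongly in the MSE.

    def random_batch(X, y, batch_size=32):
        # Assumed helper: sample a random mini-batch (with replacement)
        idx = np.random.randint(len(X), size=batch_size)
        return X[idx], y[idx]

    def print_status_bar(iteration, total, loss, metrics=None):
        # Assumed helper: show the running loss and metrics on one line
        metrics = " - ".join("{}: {:.4f}".format(m.name, float(m.result()))
                             for m in [loss] + (metrics or []))
        end = "" if iteration < total else "\n"
        print("\r{}/{} - ".format(iteration, total) + metrics, end=end)

    l2_reg = keras.regularizers.l2(0.05)
    model = keras.models.Sequential([
        keras.layers.Dense(30, activation="elu", kernel_initializer="he_normal",
                           kernel_regularizer=l2_reg),
        keras.layers.Dense(1, kernel_regularizer=l2_reg)
    ])

    n_epochs = 5
    batch_size = 32
    n_steps = len(X_train_scaled) // batch_size
    optimizer = keras.optimizers.Nadam(learning_rate=0.01)  # `lr` is deprecated
    loss_fn = keras.losses.mean_squared_error
    mean_loss = keras.metrics.Mean()
    metrics = [keras.metrics.MeanAbsoluteError()]

    for epoch in range(1, n_epochs + 1):
        print("Epoch {}/{}".format(epoch, n_epochs))
        for step in range(1, n_steps + 1):
            X_batch, y_batch = random_batch(X_train_scaled, y_train, batch_size)
            with tf.GradientTape() as tape:
                y_pred = model(X_batch)
                # Mean MSE over the batch plus one L2 penalty per regularized layer
                main_loss = tf.reduce_mean(loss_fn(y_batch, y_pred))
                loss = tf.add_n([main_loss] + model.losses)
            # Gradients of the total loss w.r.t. every trainable weight
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
            # Apply weight constraints, if any layer defines one
            for variable in model.variables:
                if variable.constraint is not None:
                    variable.assign(variable.constraint(variable))
            # Update the running mean loss and the metrics
            mean_loss(loss)
            for metric in metrics:
                metric(y_batch, y_pred)
            print_status_bar(step * batch_size, len(y_train), mean_loss, metrics)
        print_status_bar(len(y_train), len(y_train), mean_loss, metrics)
        # Reset the running loss and metrics at the end of each epoch
        for metric in [mean_loss] + metrics:
            metric.reset_state()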
    

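    Because the layers have kernel regularizers, model.losses holds the regularization loss of each layer (here one L2 penalty per Dense layer). The total loss is therefore the value of the loss function plus the sum of these regularization losses.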
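To make the loss composition concrete, here is a minimal NumPy sketch (not TensorFlow) under simplifying assumptions: a single linear layer instead of the two Dense layers, plain gradient descent instead of Nadam, and gradients derived by hand instead of by tf.GradientTape. The names compute_losses, compute_gradients, L2_FACTOR and the toy data are hypothetical. The penalty follows keras.regularizers.l2(0.05), which adds 0.05 * sum(W**2) for the kernel only, so the bias is not regularized.

```python
import numpy as np

# Mirrors keras.regularizers.l2(0.05): penalty = 0.05 * sum(W**2), kernel only
L2_FACTOR = 0.05

def compute_losses(X, y, W, b):
    """Return (main_loss, reg_loss, total_loss) for a linear layer y = X @ W + b."""
    y_pred = X @ W + b
    main_loss = np.mean((y - y_pred) ** 2)      # like tf.reduce_mean(loss_fn(...))
    reg_loss = L2_FACTOR * np.sum(W ** 2)       # like the sum of model.losses
    return main_loss, reg_loss, main_loss + reg_loss  # like tf.add_n(...)

def compute_gradients(X, y, W, b):
    """Hand-derived gradient of the total loss (what tape.gradient automates)."""
    n = len(X)
    err = X @ W + b - y
    grad_W = 2.0 / n * X.T @ err + 2.0 * L2_FACTOR * W  # MSE term + L2 term
    grad_b = 2.0 / n * err.sum(axis=0)                   # bias is not regularized
    return grad_W, grad_b

# Hypothetical toy data: 32 samples, 3 features, targets of shape (32, 1)
rng = np.random.default_rng(42)
X = rng.normal(size=(32, 3))
y = X @ np.array([[1.0], [-2.0], [0.5]]) + 0.3

W = np.zeros((3, 1))
b = np.zeros(1)
initial_total = compute_losses(X, y, W, b)[2]

learning_rate = 0.05  # plain gradient descent, not Nadam
for step in range(100):
    grad_W, grad_b = compute_gradients(X, y, W, b)
    W -= learning_rate * grad_W
    b -= learning_rate * grad_b

main_loss, reg_loss, total_loss = compute_losses(X, y, W, b)
print("main: {:.4f}  reg: {:.4f}  total: {:.4f}".format(main_loss, reg_loss, total_loss))
```

Leaving reg_loss out of the total would drop the 2 * L2_FACTOR * W term from grad_W, so the weights would no longer be pulled toward zero. In the TensorFlow loop this corresponds to differentiating main_loss alone instead of tf.add_n([main_loss] + model.losses).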

  • Original post: https://www.cnblogs.com/yaos/p/12755499.html