  • Manually implementing a TensorFlow training loop: an example

    Reference: Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow: Concepts, Tools, and Techniques to Build Intelligent Systems

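    import tensorflow as tf
    from tensorflow import keras

    # Assumes X_train, X_train_scaled and y_train are already loaded, and that the
    # helpers random_batch() and print_status_bar() are defined elsewhere.
    # L2 regularizer shared by both layers; its penalties show up in model.losses.
    l2_reg = keras.regularizers.l2(0.05)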
    model = keras.models.Sequential([
        keras.layers.Dense(30, activation="elu", kernel_initializer="he_normal",
                           kernel_regularizer=l2_reg),
        keras.layers.Dense(1, kernel_regularizer=l2_reg)
    ])
    
    n_epochs = 5
    batch_size = 32
    n_steps = len(X_train) // batch_size
    optimizer = keras.optimizers.Nadam(learning_rate=0.01)
    loss_fn = keras.losses.mean_squared_error
    mean_loss = keras.metrics.Mean()
    metrics = [keras.metrics.MeanAbsoluteError()]
    
    for epoch in range(1, n_epochs + 1):
        print("Epoch {}/{}".format(epoch, n_epochs))
        for step in range(1, n_steps + 1):
            X_batch, y_batch = random_batch(X_train_scaled, y_train)
            with tf.GradientTape() as tape:
                y_pred = model(X_batch, training=True)
                # Main loss: mean squared error averaged over the batch.
                main_loss = tf.reduce_mean(loss_fn(y_batch, y_pred))
                # Total loss: main loss plus one regularization loss per layer.
                loss = tf.add_n([main_loss] + model.losses)
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
            # Re-apply weight constraints, if any, after the update.
            for variable in model.variables:
                if variable.constraint is not None:
                    variable.assign(variable.constraint(variable))
            # Update the running mean of the loss for the status bar.
            mean_loss(loss)
            for metric in metrics:
                metric(y_batch, y_pred)
            print_status_bar(step * batch_size, len(y_train), mean_loss, metrics)
        print_status_bar(len(y_train), len(y_train), mean_loss, metrics)
        for metric in [mean_loss] + metrics:
            metric.reset_states()
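
The loop above calls random_batch() and print_status_bar(), which this post does not define. Below is a minimal sketch of what they might look like, modelled on the helpers in the referenced book. The exact signatures and the batch_size=32 default are assumptions. Both would need to be defined before the loop runs.

```python
import numpy as np

def random_batch(X, y, batch_size=32):
    """Sample batch_size rows (with replacement) from X and y."""
    # Assumption: X and y are NumPy arrays with the same first dimension.
    idx = np.random.randint(len(X), size=batch_size)
    return X[idx], y[idx]

def print_status_bar(iteration, total, loss, metrics=None):
    """Overwrite the current console line with progress and metric values."""
    # Each metric object is expected to expose .name and .result(),
    # as keras.metrics.Mean and keras.metrics.MeanAbsoluteError do.
    parts = ["{}: {:.4f}".format(m.name, float(m.result()))
             for m in [loss] + (metrics or [])]
    # Stay on the same line until the epoch is complete.
    end = "" if iteration < total else "\n"
    print("\r{}/{} - ".format(iteration, total) + " - ".join(parts), end=end)
```

Sampling with replacement keeps the helper short, at the cost that some instances may be seen more than once per epoch and others not at all.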
    

    Because the model uses a regularizer, model.losses contains the regularization loss of each layer. The total loss is the value of the loss function (the main loss) plus these regularization losses.
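
To make the arithmetic concrete, here is a small NumPy sketch of the same sum, kept independent of TensorFlow. It assumes the L2 penalty is computed as factor * sum(w**2) per kernel, which is how keras.regularizers.l2 is documented to behave. The function names and the toy values are hypothetical.

```python
import numpy as np

def l2_penalty(weights, factor=0.05):
    # Assumed form of the L2 regularization loss: factor * sum(w**2).
    return factor * np.sum(np.square(weights))

def total_loss(y_true, y_pred, kernels, factor=0.05):
    # Main loss: mean squared error over the batch.
    main_loss = np.mean(np.square(y_true - y_pred))
    # One regularization loss per regularized layer, like model.losses.
    reg_losses = [l2_penalty(W, factor) for W in kernels]
    return main_loss + sum(reg_losses)

# Toy values, chosen only for illustration.
y_true = np.array([[1.0], [2.0]])
y_pred = np.array([[1.5], [1.5]])
kernels = [np.array([[1.0, -1.0]]), np.array([[2.0], [0.0]])]

# MSE = 0.25, penalties = 0.1 and 0.2, so the total is approximately 0.55.
print(total_loss(y_true, y_pred, kernels))
```

Only the total loss is differentiated in the training loop, so the gradients push the weights toward both a lower prediction error and smaller magnitudes.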

  • Original post: https://www.cnblogs.com/yaos/p/12755499.html