zoukankan      html  css  js  c++  java
  • Tensorflow2.0笔记26——参数提取,写至文本

    Tensorflow2.0笔记

    本博客为Tensorflow2.0学习笔记,感谢北京大学微电子学院曹建老师

    2.4 参数提取,写至文本

    1.提取可训练参数

    model.trainable_variables 模型中可训练的参数

    2.设置print输出格式

    np.set_printoptions(precision=小数点后按四舍五入保留几位,threshold=数组元素数量少于或等于门槛值,打印全部元素;否则打印门槛值+1 个元素,中间用省略号补充)

    >>> np.set_printoptions(precision=5)

    >>> print(np.array([1.123456789]))

    [1.12346]

    >>> np.set_printoptions(threshold=5)

    >>> print(np.arange(10))

    [0 1 2 … , 7 8 9]

    注:precision=np.inf 打印完整小数位;threshold=np.nan 打印全部数组元素。

    import tensorflow as tf
    import os
    import numpy as np
    np.set_printoptions(threshold=np.inf)
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=['sparse_categorical_accuracy'])
    checkpoint_save_path = "./checkpoint/mnist.ckpt"
    if os.path.exists(checkpoint_save_path + '.index'):
        print('-------------load the model-----------------')
        model.load_weights(checkpoint_save_path)
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                     save_weights_only=True,
                                                     save_best_only=True)
    history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                        callbacks=[cp_callback])
    model.summary()
    print(model.trainable_variables)
    file = open('./weights.txt', 'w')
    for v in model.trainable_variables:
        file.write(str(v.name) + '
    ')
        file.write(str(v.shape) + '
    ')
        file.write(str(v.numpy()) + '
    ')
    file.close()
    

    image-20210623203949189

  • 相关阅读:
    HDU 2852 KiKi's K-Number (主席树)
    HDU 2089 不要62
    Light oj 1140 How Many Zeroes?
    Bless You Autocorrect!
    HDU 6201 transaction transaction transaction
    HDU1561 The more ,The better (树形背包Dp)
    CodeForces 607B zuma
    POJ 1651 Mulitiplication Puzzle
    CSUOJ 1952 合并石子
    Uva 1599 Ideal path
  • 原文地址:https://www.cnblogs.com/wind-and-sky/p/14924822.html
Copyright © 2011-2022 走看看