zoukankan      html  css  js  c++  java
  • Debug --> Variable,Tensor,Numpy的转换

    尝试输出keras模型参数的时候,需要解决的问题:

     1 import tensorflow.compat.v1 as tf
     2 tf.disable_v2_behavior()
     3 import numpy as np
     4 weight = tf.get_variable(name='weights',initializer=tf.random_normal([5,2], stddev=0.01))
     5 with tf.Session() as sess:
     6     sess.run(tf.global_variables_initializer())
     7     print('------------------打印出已经初始化之后的Variable的值------------------------------')
     8     print(sess.run(weight))
     9     print('----------weight的类型------------')
    10     print(type(weight))
    11     # Variable转换为Tensor
    12     # Variable类型转换为tensor类型(无论是numpy转换为Tensor还是Variable转换为Tensor都可以使用tf.convert_to_tensor)
    13     data_tensor = tf.convert_to_tensor(weight) 
    14     # 打印出Tensor的值(由Variable转化而来)
    15     print('------------------Variable转化为Tensor,打印出Tensor的值--------------------------')
    16     print(sess.run(data_tensor))
    17     # tensor转化为numpy
    18     print('-------------------tensor转换为numpy,打印出numpy的值-----------------')
    19     data_numpy = data_tensor.eval()
    20     print(data_numpy)
    21     print('------------------numpy转换为Tensor---------------------------')
    22     ten = tf.convert_to_tensor(data_numpy)
    23     print(ten)
    24     print(sess.run(ten))
    25     # tensor转化为Variable(其实是Variable继承Tensor的结构,但是没有值
    26     print('---------------------tensor转换为Variable(需要重新进行初始化)----------------------')
    27     v = tf.Variable(data_tensor) # 此时Variable继承的是Tensor的结构,至于Variable的值,需要重新进行initialize
    28     sess.run(tf.global_variables_initializer())
    29     print(sess.run(weight)) # 此时输出的weight和v的结构是相同的,但是值是不同的。
    30     print(sess.run(v))
    31     
    32 #     tf.enable_eager_execution(
    33 #     config=None,
    34 #     device_policy=None,
    35 #     execution_mode=None
    36 #     )
    37     # Variable转换为numpy(也是使用eval)
    38     print('---------------Variable转换为numpy(也是使用eval)--------------------')
    39     data_numpy2 = weight.eval()
    40     print(data_numpy2)

    1.模型保存

    model.save_model()可以保存网络结构权重以及优化器的参数
    model.save_weights() 仅仅保存权重

    2.模型加载

    from keras.models import load_model
    load_model():只能load 由save_model保存的,将模型和weight全load进来

    model.load_weights(self, filepath, by_name=False):在加载权重之前,model必须编译好

    3.sequential 和functional

    序列式模型只能有单输入单输出,函数式模型可以有多个输入输出

    4.model类

    因为是继承, model对象有 container和layer的所有方法,可以用model对象访问下面三个类的所有方法

    Model(Container)containerlayer
    fit summary get_input_at(node_index)
    evaluate get_layer get_config()
    predict get_weights compute_mask(x, mask)
    train on batch set_weights get_input_mask_at(node_index)
    test_on_batch get_config get_output_at(node_index)
    predict_on_batch compute_output_shape  
    evaluate_generator    
    predict_generator    
     

    5.打印各层权重

    layer.get_weights返回的是没有名字的权重array,Model.get_weights() 是他们的拼接,也没有名字,利用layer.weights 可以访问到后台的变量

    1 #打印各层名字,权重的形状
    2 for layer in model.layers:
    3         for weight in layer.weights:
    4             print weight.name,weight.shape
    
    

    上面输出的weight是Var类型,下面给出另一种方法,输出的weight是np.Array类型:

    1 names = [weight.name for layer in model.layers for weight in layer.weights]
    2 weights = model.get_weights()
    3 for name, weight in zip(names, weights):
    4     print(name, weight.shape) 


    To see I can not see, to know I do not know.
  • 相关阅读:
    Python学习之余,摸摸鱼
    Python 实现斐波那契数
    Linux下为什么目录的大小总是4096
    Python的精髓居然是方括号、花括号和圆括号!
    为什么说Python是最伟大的语言?看图就知道了!
    前端常用知识(会更新)
    Mysql 约束
    Navicat 安装
    Java后台将CTS格式转为标准日期时间格式返回给前端
    MySQL数据库报错“Zero date value prohibited”
  • 原文地址:https://www.cnblogs.com/aluomengmengda/p/14679858.html
Copyright © 2011-2022 走看看