zoukankan      html  css  js  c++  java
  • tensorflow用法记录

    使用 embedding 变量

    import tensorflow as tf
    import numpy as np

    sess = tf.InteractiveSession()

    # Vocabulary 'A'..'D' maps to ids 0..3; any other token (e.g. 'E') is hashed
    # into the single out-of-vocabulary bucket, i.e. id 4 = len(vocab) + 0.
    table = tf.contrib.lookup.index_table_from_tensor(
        mapping=tf.constant(list('ABCD')),
        num_oov_buckets=1,
        default_value=-1,
    )

    # Each row is one example carrying several '|'-separated ids.
    id_strings = tf.Variable(["A|B|C", "C|D|A|B", "C|A|A|B|E|E"])
    # A 20x1 matrix whose row i holds the value i, so pooled results are easy to check.
    embedding_mat = tf.constant(np.arange(20, dtype=float).reshape((20, 1)))

    # Split rows into tokens, map tokens to integer ids, and rebuild a
    # SparseTensor with the same layout for embedding_lookup_sparse.
    tokens = tf.string_split(id_strings, '|')
    sparse_ids = tf.SparseTensor(
        tokens.indices, table.lookup(tokens.values), tokens.dense_shape)

    # Pool the looked-up rows per example: average first, then total.
    pooled = {
        combiner: tf.nn.embedding_lookup_sparse(
            embedding_mat, sparse_ids, sp_weights=None, combiner=combiner)
        for combiner in ('mean', 'sum')
    }

    # In TF1 both the lookup table and the variable need explicit initialization.
    tf.tables_initializer().run()
    tf.global_variables_initializer().run()

    print(pooled['mean'].eval())
    print()
    print(pooled['sum'].eval())
    

    结果

    [[ 1.        ]
     [ 1.5       ]
     [ 1.83333333]]
    
    [[  3.]
     [  6.]
     [ 11.]]
    

    修改learning_rate

    最初尝试用 feed_dict 传入新的学习率，但在 TensorBoard 中发现学习率实际上并没有改变（feed_dict 只会在单次 run 中临时替换张量的取值，不会修改变量本身保存的值）。查阅资料后发现，下面这种用 assign 直接写入变量的方式最简单：

    # Halve the Python-side learning rate, then write it into the TF variable
    # with an explicit assign op so the stored variable value actually changes.
    # NOTE(review): in TF1 graph mode each call to .assign() adds a new op to the
    # graph; if this runs repeatedly, consider building the assign op once (fed
    # by a placeholder) and reusing it.
    self.learning_rate /= 2.0
    _ = sess.run(self.learning_rate_var.assign(self.learning_rate))
    

    tf.keras的用法

    • Layer 会自动生成 mask 张量，并传递给下一个 layer 继续使用；这是通过输出张量的 `_keras_mask` 属性实现的。这一机制在序列模型中非常有用。
    from tensorflow.python.keras.layers import Embedding
    # With mask_zero=True, input id 0 is masked out; the boolean mask is attached
    # to the output tensor as `_keras_mask`.
    e = Embedding(100, 3, mask_zero=True)(tf.constant([1,3,4,0]))
    # .eval() needs a default session (presumably an InteractiveSession is active).
    e._keras_mask.eval()
    # Output: array([ True,  True,  True, False])
    
    • 属性的用法
    # Dynamic attribute access by name: equivalent to `e._keras_mask0 = 1`
    # followed by reading `e._keras_mask0` (which yields 1).
    setattr(e, '_keras_mask0', 1)
    getattr(e, '_keras_mask0')
    
    • mask 举例说明
    import tensorflow as tf
    from tensorflow.python.keras.layers import Input
    from tensorflow.python.keras.layers import Embedding
    from tensorflow.compat.v1 import logging
    # Only show TF log messages at ERROR level or above; presumably this keeps
    # warnings out of the demo's printed output.
    logging.set_verbosity(logging.ERROR)
    
    
    class NoMask(tf.keras.layers.Layer):
        """Identity layer that stops mask propagation.

        The input tensor is returned unchanged, but `compute_mask` always
        returns None, so the output's `_keras_mask` is None even when the
        input carried a mask.
        """

        def call(self, x, mask=None, **kwargs):
            # Declaring `mask` lets Keras hand the upstream mask to this layer;
            # it is deliberately ignored.
            return x

        def compute_mask(self, inputs, mask=None):
            # `mask=None` matches the base Layer.compute_mask signature so the
            # method can also be called with just `inputs`.
            # Returning None drops whatever mask arrived from upstream.
            return None
    
    
    class hasMask(tf.keras.layers.Layer):
        """Identity layer that forwards the incoming mask unchanged.

        `supports_masking` is not set here; it is the `compute_mask` override
        that carries the upstream mask through to the output tensor.
        """

        def call(self, x, mask=None, **kwargs):
            # The upstream mask is accepted but not used by the computation.
            return x

        def compute_mask(self, inputs, mask=None):
            # `mask=None` matches the base Layer.compute_mask signature so the
            # method can also be called with just `inputs`.
            # Pass the upstream mask (possibly None) straight through.
            return mask
    
    
    class hasMask2(tf.keras.layers.Layer):
        """Identity layer whose mask behaviour is driven by `supports_masking`.

        No `compute_mask` override is defined, so the inherited Keras logic
        decides the output mask from the `supports_masking` flag given to the
        constructor (in the demo output, True forwards the mask and False
        yields None).
        """

        def __init__(self, supports_masking=False, **kwargs):
            super().__init__(**kwargs)
            # Assigned after the base initializer has run, presumably so the
            # base class cannot overwrite the caller's choice.
            self.supports_masking = supports_masking

        def call(self, x, mask=None, **kwargs):
            # Pure pass-through; the mask argument is accepted but unused.
            return x
            return x
    
    
    if __name__ == "__main__":
        # Run the same comparison with and without zero-masking in the Embedding.
        for run_idx, use_zero_mask in enumerate((True, False)):
            if run_idx:
                # Separator between the two runs.
                print("*" * 50)
            e = Embedding(100, 3, mask_zero=use_zero_mask)(Input(shape=(1,), name='cate_id', dtype='int32'))
            # Mask observed on the output of each layer, in this order:
            # Embedding, hasMask, hasMask2(True), hasMask2(False), NoMask.
            observed = (
                e._keras_mask,
                hasMask()(e)._keras_mask,
                hasMask2(True)(e)._keras_mask,
                hasMask2(False)(e)._keras_mask,
                NoMask()(e)._keras_mask,
            )
            print(*observed)
    

    输出

    Tensor("embedding/NotEqual:0", shape=(?, 1), dtype=bool) Tensor("embedding/NotEqual:0", shape=(?, 1), dtype=bool) Tensor("embedding/NotEqual:0", shape=(?, 1), dtype=bool) None None
    **************************************************
    None None None None None
    
  • 相关阅读:
    系统信息查看命令
    item pipeline 实例:爬取360摄像图片
    scrapy之 downloader middleware
    scrapy 中用selector来提取数据的用法
    Live2D 看板娘
    Live2D 看板娘
    Live2D 看板娘
    Live2D 看板娘
    Live2D 看板娘
    Live2D 看板娘
  • 原文地址:https://www.cnblogs.com/bregman/p/7845160.html
Copyright © 2011-2022 走看看