zoukankan      html  css  js  c++  java
  • TF代码片段

    • keys_len 如果是None的时候, 注释代码失败。
    # keys_len = keys.get_shape()[1]
    # queries = K.repeat_elements(query, keys_len, 1)
    
    keys_len = tf.shape(keys)[1]
    multiples = tf.stack([1 if i != 1 else keys_len   for i in range(len(query.get_shape()))])
    queries = tf.tile(query, multiples)
    
    • 代码优化。 重复key很多且embedding 矩阵很大。 用tf.contrib.layers.embedding_lookup_unique 代替 embedding_lookup, 减少通讯量。
    • 对稀疏矩阵的列分组
    n_cols = 2
    n_rows = 4
    
    I = tf.constant([0,0,0,0, 1,1,1, 3,3,3,3], dtype=tf.int64)
    J = tf.constant([0,0,1,1, 1,1,0, 1,0,0,0], dtype=tf.int64)
    
    V = tf.constant([1,1,1,1, 1,1,1, 1,1,1,1], shape=(11,1),dtype=tf.float32)
    Idx = I * n_cols + J
    
    B = tf.unsorted_segment_sum(V, Idx,  n_cols * n_rows)
    tf.reshape(B, [n_rows, n_cols]).eval()
    '''
    array([[2., 2.],
           [1., 2.],
           [0., 0.],
           [3., 1.]], dtype=float32)
    '''
    
  • 相关阅读:
    把数组排成最小的数
    整数中1出现的次数
    连续子数组的最大和
    快速排序
    penCV入门
    OpenCV视频操作
    linux下导入oracle数据表
    js工作备注
    oracle创建默认表空间---重要
    oracle的导入导出
  • 原文地址:https://www.cnblogs.com/bregman/p/13999900.html
Copyright © 2011-2022 走看看