zoukankan      html  css  js  c++  java
  • tf更新tensor/自定义层

    修改Tensor特定位置的值

    stack overflow 中提到的方案。
    TensorFlow不让你直接单独改指定位置的值,但是留了个歪门儿,就是tf.scatter_update这个方法,它可以批量替换张量某一维上的所有数据。

    def set_value(matrix, x, y, val):
        # 提取出要更新的行
        row = tf.gather(matrix, x)
        # 构造这行的新数据
        new_row = tf.concat([row[:y], [val], row[y+1:]], axis=0)
        # 使用 tf.scatter_update 方法进正行替换
        matrix.assign(tf.scatter_update(matrix, x, new_row)) 
    

    但是这么做有没什么缺点呢?有,那就是慢,特别是矩阵很大的时候,那是真心的慢。
    TensorFlow是对张量运算(其实二维的就是矩阵运算)有速度优化的,能不能将张量修改的操作变成一个普通的张量运算呢?能,再构建一个差值张量然后做个加法,哎,又是一条旁门邪道。

    def set_value(matrix, x, y, val):
        # 得到张量的宽和高,即第一维和第二维的Size
        w = int(matrix.get_shape()[0])
        h = int(matrix.get_shape()[1])
        # 构造一个只有目标位置有值的稀疏矩阵,其值为目标值于原始值的差
        val_diff = val - matrix[x][y]
        diff_matrix = tf.sparse_tensor_to_dense(tf.SparseTensor(indices=[x, y], values=[val_diff], dense_shape=[w, h]))
        # 用 Variable.assign_add 将两个矩阵相加
        matrix.assign_add(diff_matrix)
    

    cs20si课程作业1的第3题 后一种方法的效率大概提升了4倍。

    Shuffling input files with tensorflow Datasets

    问题

    按文件列表顺序读取

    BUFFER_SIZE = 1000 # arbitrary number
    # define filenames somewhere, e.g. via glob
    dataset = tf.data.TFRecordDataset(filenames).shuffle(BUFFER_SIZE)
    

    shuffle文件,然后读取

    dataset = tf.data.Dataset.from_tensor_slices(filenames)
    dataset = dataset.shuffle(BUFFER_SIZE) # doesn't need to be big
    dataset = dataset.flat_map(tf.data.TFRecordDataset)
    dataset = dataset.map(decode_example, num_parallel_calls=5) # add your decoding logic here
    # further processing of the dataset
    

    同时从多个文件读取

    dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=4)
    

    TF自定义梯度

    自定义梯度

    多个op
    See also tf.RegisterGradient which registers a gradient function for a primitive TensorFlow operation. tf.custom_gradient on the other hand allows for fine grained control over the gradient computation of a sequence of operations.

    keras 不支持 去用pytorch吧

  • 相关阅读:
    仿百度翻页(转)
    文字顺时针旋转90度(纵向)&古诗词排版
    微信小程序使用canvas绘制图片的注意事项
    PHP即时实时输出内容
    使用Android Studio遇到的问题
    RuntimeError: Model class users.models.UserProfile doesn't declare an explicit app_label and isn't in an application in INSTALLED_APPS.
    drf中的各种view,viewset
    代码审计:covercms 1.6
    windows下安装phpredis扩展
    python练习:异常
  • 原文地址:https://www.cnblogs.com/houkai/p/10333164.html
Copyright © 2011-2022 走看看