zoukankan      html  css  js  c++  java
  • 高阶操作

    Outline

    • where

    • scatter_nd

    • meshgrid

    Where

    where(tensor)

    • where获得以下表格中True的位置
    123
    True False False
    False True False
    False False True
    import tensorflow as tf
    
    a = tf.random.normal([3, 3])
    a
    
    <tf.Tensor: id=11, shape=(3, 3), dtype=float32, numpy=
    array([[-0.02527909, -0.09084062,  0.34427297],
           [-0.45223615,  1.1085868 , -1.9480664 ],
           [-2.3520288 , -1.8698558 , -0.30862013]], dtype=float32)>
    
    mask = a > 0
    mask
    
    <tf.Tensor: id=16, shape=(3, 3), dtype=bool, numpy=
    array([[False, False,  True],
           [False,  True, False],
           [False, False, False]])>
    
    # 为True元素的值
    tf.boolean_mask(a, mask)
    
    <tf.Tensor: id=44, shape=(2,), dtype=float32, numpy=array([0.34427297, 1.1085868 ], dtype=float32)>
    
    # 为True元素,即>0的元素的索引
    indices = tf.where(mask)
    indices
    
    <tf.Tensor: id=47, shape=(2, 2), dtype=int64, numpy=
    array([[0, 2],
           [1, 1]])>
    
    # 取回>0的值
    tf.gather_nd(a, indices)
    
    <tf.Tensor: id=49, shape=(2,), dtype=float32, numpy=array([0.34427297, 1.1085868 ], dtype=float32)>
    

    where(cond,A,B)

    mask
    
    <tf.Tensor: id=16, shape=(3, 3), dtype=bool, numpy=
    array([[False, False,  True],
           [False,  True, False],
           [False, False, False]])>
    
    A = tf.ones([3, 3])
    B = tf.zeros([3, 3])
    
    # True的元素会从A中选值,False的元素会从B中选值
    tf.where(mask, A, B)
    
    <tf.Tensor: id=61, shape=(3, 3), dtype=float32, numpy=
    array([[0., 0., 1.],
           [0., 1., 0.],
           [0., 0., 0.]], dtype=float32)>
    

    scatter_nd

    • tf.scatter_nd(
    • indices,
    • updates,
    • shape)

    一维

    13-高阶操作-scatter_nd.jpg

    indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9, 10, 11, 12])
    shape = tf.constant([8])
    
    # 把updates按照indices的索引放在底板shape上
    tf.scatter_nd(indices, updates, shape)
    
    <tf.Tensor: id=71, shape=(8,), dtype=int32, numpy=array([ 0, 11,  0, 10,  9,  0,  0, 12], dtype=int32)>
    

    二维

    13-高阶操作-scatter_nd2.jpg

    indices = tf.constant([[0], [2]])
    updates = tf.constant([
        [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
        [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
    ])
    updates.shape
    
    TensorShape([2, 4, 4])
    
    shape = tf.constant([4, 4, 4])
    
    tf.scatter_nd(indices, updates, shape)
    
    <tf.Tensor: id=76, shape=(4, 4, 4), dtype=int32, numpy=
    array([[[5, 5, 5, 5],
            [6, 6, 6, 6],
            [7, 7, 7, 7],
            [8, 8, 8, 8]],
    
           [[0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0]],
    
           [[5, 5, 5, 5],
            [6, 6, 6, 6],
            [7, 7, 7, 7],
            [8, 8, 8, 8]],
    
           [[0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0]]], dtype=int32)>
    

    meshgrid

    • [-2,-2]
    • [-1,-2]
    • [0,-2]
    • [-2,-2]
    • [-1,-1]
    • ...
    • [2,2]

    13-高级操作-meshgrid.jpg

    Points

    • [y,x,w]

      • [5,5,2]
    • [N,2]

    14-高阶操作-Points.jpg

    numpy实现

    import numpy as np
    
    points = []
    
    for y in np.linspace(-2, 2, 5):
        for x in np.linspace(-2, 2, 5):
            points.append([x, y])
    
    np.array(points)
    
    array([[-2., -2.],
           [-1., -2.],
           [ 0., -2.],
           [ 1., -2.],
           [ 2., -2.],
           [-2., -1.],
           [-1., -1.],
           [ 0., -1.],
           [ 1., -1.],
           [ 2., -1.],
           [-2.,  0.],
           [-1.,  0.],
           [ 0.,  0.],
           [ 1.,  0.],
           [ 2.,  0.],
           [-2.,  1.],
           [-1.,  1.],
           [ 0.,  1.],
           [ 1.,  1.],
           [ 2.,  1.],
           [-2.,  2.],
           [-1.,  2.],
           [ 0.,  2.],
           [ 1.,  2.],
           [ 2.,  2.]])
    

    tensorflow2实现

    y = tf.linspace(-2., 2, 5)
    y
    
    <tf.Tensor: id=81, shape=(5,), dtype=float32, numpy=array([-2., -1.,  0.,  1.,  2.], dtype=float32)>
    
    x = tf.linspace(-2., 2, 5)
    x
    
    <tf.Tensor: id=86, shape=(5,), dtype=float32, numpy=array([-2., -1.,  0.,  1.,  2.], dtype=float32)>
    
    points_x, points_y = tf.meshgrid(x, y)
    points_x.shape
    
    TensorShape([5, 5])
    
    points_x
    
    <tf.Tensor: id=130, shape=(5, 5), dtype=float32, numpy=
    array([[-2., -1.,  0.,  1.,  2.],
           [-2., -1.,  0.,  1.,  2.],
           [-2., -1.,  0.,  1.,  2.],
           [-2., -1.,  0.,  1.,  2.],
           [-2., -1.,  0.,  1.,  2.]], dtype=float32)>
    
    points_y
    
    <tf.Tensor: id=131, shape=(5, 5), dtype=float32, numpy=
    array([[-2., -2., -2., -2., -2.],
           [-1., -1., -1., -1., -1.],
           [ 0.,  0.,  0.,  0.,  0.],
           [ 1.,  1.,  1.,  1.,  1.],
           [ 2.,  2.,  2.,  2.,  2.]], dtype=float32)>
    
    points = tf.stack([points_x, points_y], axis=2)
    points
    
    <tf.Tensor: id=135, shape=(5, 5, 2), dtype=float32, numpy=
    array([[[-2., -2.],
            [-1., -2.],
            [ 0., -2.],
            [ 1., -2.],
            [ 2., -2.]],
    
           [[-2., -1.],
            [-1., -1.],
            [ 0., -1.],
            [ 1., -1.],
            [ 2., -1.]],
    
           [[-2.,  0.],
            [-1.,  0.],
            [ 0.,  0.],
            [ 1.,  0.],
            [ 2.,  0.]],
    
           [[-2.,  1.],
            [-1.,  1.],
            [ 0.,  1.],
            [ 1.,  1.],
            [ 2.,  1.]],
    
           [[-2.,  2.],
            [-1.,  2.],
            [ 0.,  2.],
            [ 1.,  2.],
            [ 2.,  2.]]], dtype=float32)>
    

    14-高阶操作-等高线图.jpg

  • 相关阅读:
    scss文件报错处理 (报错信息Invalid CSS after "v": expected 1 selector or at-rule, was 'var api = require)
    vue-countdown组件
    vue dayjs in ./node_modules/babel-loader/lib!./node_modules/vue-loader/lib/selector.js
    vue you can run: npm install --save !!vue-styles-loader!css-loader?
    解决npm报错:Module build failed: TypeError: this.getResolve is not a function
    【JVM从小白学成大佬】3.深入解析强引用、软引用、弱引用、幻象引用
    【JVM从小白学成大佬】2.Java虚拟机运行时数据区
    【JVM从小白学成大佬】开篇
    【必知必会】深入解析强引用、软引用、弱引用、幻象引用
    【不做标题党,只做纯干货】HashMap在jdk1.7和1.8中的实现
  • 原文地址:https://www.cnblogs.com/abdm-989/p/14123255.html
Copyright © 2011-2022 走看看