zoukankan      html  css  js  c++  java
  • tensorflow(十七):高阶操作:tf.where(),tf.scatter_nd(), tf.meshgrid()

    一、tf.where()对tensor操作返回一系列坐标

     

     

     二、tf.scatter_nd()根据坐标有目的的进行更新

     

     三、tf.meshgrid()对tensor如果要画一个3D的图片,要生成一个3D的坐标轴

     

     

     

     

     

     四、实战

    import tensorflow as tf
    
    import matplotlib.pyplot as plt
    
    def func(x):
        """
    
        :param x: [b,2] 2的前面部分就是x坐标,后面就是y坐标。
        :return:
        """
        z=tf.math.sin(x[...,0]) + tf.math.sin(x[...,1]) #...表示取所有维度的0,前面所有点的x部分,后者所有点y部分。
        return z
    
    x = tf.linspace(0., 2*3.14, 500)
    y = tf.linspace(0., 2*3.14, 500)
    # [50, 50]
    point_x, point_y = tf.meshgrid(x, y)
    #[50, 50, 2]
    points = tf.stack([point_x, point_y], axis=2)
    # points = tf.reshape(points, [-1, 2])
    
    print('points: ',points.shape)
    z = func(points)
    print("z: ", z.shape)
    
    plt.figure('plot 2d func value')
    plt.imshow(z, origin='lower', interpolation='none')
    plt.colorbar()
    
    plt.figure('plot 2d func contour')
    plt.contour(point_x, point_y, z)
    plt.colorbar()
    plt.show()
  • 相关阅读:
    自定义一个运行时异常
    对象的知识点正确解释
    decimal模块
    B+树
    Web框架系列之Tornado
    初识git
    Mysql表的操作
    MySQl创建用户和授权
    MySql安装和基本管理
    为什么用Mysql?
  • 原文地址:https://www.cnblogs.com/zhangxianrong/p/14612282.html
Copyright © 2011-2022 走看看