zoukankan      html  css  js  c++  java
  • Tensorflow

    XDeepFM的CIN中第一层实现需要使两个二维矩阵相乘得到一个三维张量,于是来复习下split函数(需要用到):
    首先看下函数原理:

    tf.split(
        value,
        num_or_size_splits,
        axis=0,
        num=None,
        name='split'
    )

    这个函数是用来切割张量的:输入切割的张量和参数,返回切割的结果。
    value传入的就是需要切割的张量,axis的数值代表切割哪个维度。
    这个函数有两种切割的方式:

    以三个维度的张量为例,比如说一个20 * 30 * 40的张量my_tensor,就如同一个长20厘米宽30厘米高40厘米的蛋糕,每立方厘米都是一个分量。

    有两种切割方式:
    1. 如果num_or_size_splits传入的是一个整数,这个整数代表这个张量最后会被切成几个小张量。此时,传入axis的数值就代表切割哪个维度(从0开始计数)。调用tf.split(my_tensor, 2,0)返回两个10 * 30 * 40的小张量。
    2. 如果num_or_size_splits传入的是一个向量,那么向量有几个分量就分成几份,切割的维度还是由axis决定。比如调用tf.split(my_tensor, [10, 5, 25], 2),则返回三个张量分别大小为 20 * 30 * 10、20 * 30 * 5、20 * 30 * 25。很显然,传入的这个向量各个分量加和必须等于axis所指示原张量维度的大小 (10 + 5 + 25 = 40)。

    一个实例:

    import tensorflow as tf
    import numpy as np
    
    arr1 = tf.convert_to_tensor(np.arange(1,25).reshape(2,4,3),dtype=tf.int32)
    
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        split_arr1 = tf.split(arr1,[1,1,1],2) # 切割成2个2*4*1的张量
       print(sess.run(split_arr1)

    可以看到原来的2*4*3的张量被切割为了3个2*4*1的张量

    Reference:

    https://blog.csdn.net/SangrealLilith/article/details/80272346

  • 相关阅读:
    如何结合后台数据库 启动vue项目
    nodejs卸载安装
    mysql安装过程
    VUE-cli脚手架
    css伪类
    element中遇到的表格问题总结
    小程序折叠面板的功能
    vue学习中遇到的onchange、push、splice、forEach方法使用
    vscode好用的扩展及常用的快捷键
    Flutter之SliverAppBar
  • 原文地址:https://www.cnblogs.com/Jesee/p/11277868.html
Copyright © 2011-2022 走看看