  • Implementing lookalike in Paddle (paddle reshape)

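    Requirement: compute attention between a batch*1*32 tensor and a batch*10*32 tensor (one target-user vector against 10 seed vectors per sample).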
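
    Written out for one sample $b$, the requested attention is dot-product scores, a softmax over the ten seeds, and a weighted sum of the seeds. Here $q_b \in \mathbb{R}^{1\times 32}$ (target-user vector), $K_b \in \mathbb{R}^{10\times 32}$ (seed vectors, row $K_{b,j}$), $s_b$, $\alpha_b$ and $o_b$ are notation introduced only for this sketch; note that the post's code applies no $1/\sqrt{32}$ scaling to the scores:

```latex
s_b = q_b K_b^{\top} \in \mathbb{R}^{1\times 10}, \qquad
\alpha_{b,j} = \frac{\exp(s_{b,j})}{\sum_{k=1}^{10} \exp(s_{b,k})}, \qquad
o_b = \sum_{j=1}^{10} \alpha_{b,j}\, K_{b,j} \in \mathbb{R}^{1\times 32}
```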

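    import paddle.fluid as fluid  # Paddle 1.x static-graph (fluid) API

    # method of the model class
    def local_attention_unit(self, target_user, user_seeds):
        # target_user: (batch, 32) -> (batch, 1, 32)
        user_target_reshape = fluid.layers.unsqueeze(target_user, axes=[1])
        # user_seeds: (batch, 320) -> (batch, 10, 32), i.e. 10 seed vectors per sample
        user_seeds_reshape = fluid.layers.reshape(user_seeds, shape=[-1, 10, 32])
        # dot product of the target with every seed: (batch, 1, 10)
        out = fluid.layers.matmul(user_target_reshape, user_seeds_reshape, transpose_y=True)
        # attention weights, softmax over the last axis (the 10 seeds): (batch, 1, 10)
        out = fluid.layers.softmax(out)
        # weighted sum of the seed vectors: (batch, 1, 32)
        out = fluid.layers.matmul(out, user_seeds_reshape)
        # remove the singleton axis: (batch, 32) = batch_size * emb_size
        out = fluid.layers.reduce_sum(out, dim=1)
        return out

    # in the network-building code of the same class:
    self.lookalike_cluster = fluid.layers.data(name="lookalike_cluster", shape=[-1, 320], dtype="float32", lod_level=0, append_batch_size=False)
    self.user_gcf_vec = fluid.layers.data(name="user_gcf_vec", shape=[-1, 32], dtype="float32", lod_level=0, append_batch_size=False)
    attention_unit_out = self.local_attention_unit(self.user_gcf_vec, self.lookalike_cluster)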
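
    The shape bookkeeping above can be checked without Paddle. The following NumPy sketch mirrors `local_attention_unit` step by step. `local_attention_unit_np` is a hypothetical helper written for illustration (not a Paddle API), and it assumes the 320 columns of `lookalike_cluster` hold 10 consecutive 32-dim seed vectors:

```python
import numpy as np

def local_attention_unit_np(target_user, user_seeds, num_seeds=10, emb_size=32):
    """Hypothetical NumPy mirror of local_attention_unit (not a Paddle API)."""
    # (batch, 32) -> (batch, 1, 32), like unsqueeze(axes=[1])
    q = target_user[:, np.newaxis, :]
    # (batch, 320) -> (batch, 10, 32), assuming seeds are stored back to back
    k = user_seeds.reshape(-1, num_seeds, emb_size)
    # dot product with every seed: (batch, 1, 10), like matmul(..., transpose_y=True)
    scores = q @ k.transpose(0, 2, 1)
    # softmax over the last axis; subtracting the max only improves numerical stability
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # weighted sum of the seeds: (batch, 1, 32)
    out = weights @ k
    # drop the singleton axis: (batch, 32), like reduce_sum(dim=1)
    return out.sum(axis=1)

rng = np.random.default_rng(0)
target_user = rng.standard_normal((4, 32)).astype(np.float32)         # plays user_gcf_vec
lookalike_cluster = rng.standard_normal((4, 320)).astype(np.float32)  # plays lookalike_cluster
out = local_attention_unit_np(target_user, lookalike_cluster)
print(out.shape)  # (4, 32)
```

    Because the second matmul already yields shape (batch, 1, 32), the final `reduce_sum` over axis 1 sums a single element and therefore just squeezes that axis away.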

    How reshape works in detail:

    import paddle
    import paddle.fluid as fluid
    import numpy as np

    # shape 2*6
    data_x = np.array([[1.0, 1.0, 1.0, 3.0, 3.0, 3.0], [1.0, 2.0, 1.0, 4.0, 3.0, 5.0]])
    print(data_x)
    with fluid.dygraph.guard():
        x = fluid.dygraph.to_variable(data_x)
        # -1 is inferred as 12 / (2*3) = 2, so the result has shape 2*2*3:
        # each row of 6 becomes 2 rows of 3, giving
        # [[[1, 1, 1], [3, 3, 3]], [[1, 2, 1], [4, 3, 5]]]
        out_z2 = fluid.layers.reshape(x, shape=[-1, 2, 3])
        print(out_z2.numpy())
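
    reshape does not move data: it reads the elements in row-major (C) order and refills them into the new shape, and a single -1 is inferred from the total element count. A dependency-free sketch of that rule follows; `reshape_row_major` and `flatten` are hypothetical helpers for illustration, not Paddle APIs:

```python
from math import prod

def flatten(nested):
    """Collect the elements of a nested list in row-major order."""
    if not isinstance(nested, list):
        return [nested]
    return [x for item in nested for x in flatten(item)]

def reshape_row_major(nested, shape):
    """Hypothetical illustration of reshape semantics (row-major, one -1 allowed)."""
    flat = flatten(nested)
    shape = list(shape)
    if shape.count(-1) > 1:
        raise ValueError("only one dimension can be -1")
    if -1 in shape:
        known = prod(d for d in shape if d != -1)
        if known == 0 or len(flat) % known != 0:
            raise ValueError("cannot infer the -1 dimension")
        shape[shape.index(-1)] = len(flat) // known
    if prod(shape) != len(flat):
        raise ValueError("element count mismatch")

    def build(offset, dims):
        # a block of shape dims starts at flat[offset]; sub-blocks are step apart
        if not dims:
            return flat[offset]
        step = prod(dims[1:])
        return [build(offset + i * step, dims[1:]) for i in range(dims[0])]

    return build(0, shape)

data_x = [[1.0, 1.0, 1.0, 3.0, 3.0, 3.0],
          [1.0, 2.0, 1.0, 4.0, 3.0, 5.0]]
print(reshape_row_major(data_x, [-1, 2, 3]))
# [[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]], [[1.0, 2.0, 1.0], [4.0, 3.0, 5.0]]]
```

    The same rule explains `reshape(user_seeds, shape=[-1, 10, 32])` in the attention code: each 320-wide row is cut into 10 consecutive 32-dim slices, which is only meaningful if the upstream feature stores the seed vectors back to back.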

  • Original article: https://www.cnblogs.com/zle1992/p/15245182.html