  • Attention mechanism



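    import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder)

    batch_size = 60
    hidden_size_att = 200
    attention_size = 30
    lensen = 15  # sequence length (time steps)
    # RNN outputs in time-major layout: (15, 60, 200)
    outputs = tf.placeholder(tf.float32, [lensen, batch_size, hidden_size_att])
    inputs = tf.transpose(outputs, [1, 0, 2])
    #inputs: (60, 15, 200), batch-major
    w_att = tf.Variable(tf.random_normal([hidden_size_att, attention_size], stddev=0.1))
    #w_att: (200, 30)
    b_att = tf.Variable(tf.random_normal([attention_size], stddev=0.1))
    #b_att: (30,)
    # Flatten the batch-major tensor so row i is (sentence i // 15, step i % 15);
    # this ordering must match the reshape to [-1, lensen] below.
    H = tf.matmul(tf.reshape(inputs, [-1, hidden_size_att]), w_att) + tf.reshape(b_att, [1, -1])
    #H: (900, 30)
    M = tf.tanh(H)
    #M: (900, 30)
    v_att = tf.Variable(tf.random_normal([attention_size], stddev=0.1))
    #v_att: (30,)
    scores = tf.reshape(tf.matmul(M, tf.reshape(v_att, [-1, 1])), [-1, lensen])
    #scores: (60, 15)
    # Softmax over the time axis. Applying it to the (900, 1) tensor instead would
    # normalize a size-1 axis and turn every weight into 1.
    A = tf.nn.softmax(scores)
    #A: (60, 15), each row sums to 1
    R = tf.reduce_sum(inputs * tf.expand_dims(A, -1), 1)
    #R: (60, 200)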
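    The same computation can be checked outside TensorFlow. Below is a minimal sketch that assumes NumPy; the helper names `softmax` and `attention_pool` are illustrative, not from the original. It computes the scores in batch-major order, normalizes them over the 15 time steps, and also shows what happens if softmax is applied to a (900, 1) score column instead: every weight becomes exactly 1.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


def attention_pool(inputs, w_att, b_att, v_att):
    """Additive attention pooling over time (hypothetical helper).

    inputs: (batch, time, hidden), batch-major
    w_att: (hidden, att), b_att: (att,), v_att: (att,)
    Returns (pooled, weights) with shapes (batch, hidden) and (batch, time).
    """
    batch, time, hidden = inputs.shape
    M = np.tanh(inputs.reshape(-1, hidden) @ w_att + b_att)  # (batch*time, att)
    scores = (M @ v_att).reshape(batch, time)                 # (batch, time)
    weights = softmax(scores, axis=1)                         # normalize over time
    pooled = (inputs * weights[:, :, None]).sum(axis=1)       # (batch, hidden)
    return pooled, weights


rng = np.random.default_rng(0)
inputs = rng.normal(size=(60, 15, 200))
w_att = rng.normal(scale=0.1, size=(200, 30))
b_att = rng.normal(scale=0.1, size=(30,))
v_att = rng.normal(scale=0.1, size=(30,))

pooled, weights = attention_pool(inputs, w_att, b_att, v_att)
print(pooled.shape, weights.shape)  # (60, 200) (60, 15)

# Pitfall: softmax over the last axis of a (900, 1) column normalizes a size-1 axis,
# so every "weight" is exactly 1 and the pooling degenerates into a plain sum.
flat_scores = (np.tanh(inputs.reshape(-1, 200) @ w_att + b_att) @ v_att)[:, None]
print(np.allclose(softmax(flat_scores, axis=-1), 1.0))  # True
```

    The key design point is where the softmax runs: the scores must be reshaped to (batch, time) first so that the normalization happens across the words of each sentence.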


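    import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.contrib)

    hidden_size = 100
    batch_size = 60
    max_time = 15
    att_size = 30
    # hidden_size * 2: e.g. concatenated forward/backward RNN states
    inputs = tf.placeholder(tf.float32, [batch_size, max_time, hidden_size * 2])
    #inputs: (60, 15, 200)
    # Learned context vector that every projected time step is scored against
    u_context = tf.Variable(tf.truncated_normal([att_size]), name='u_context')
    #u_context: (30,)
    # Dense layer applied to every time step: tanh(inputs W + b)
    h = tf.contrib.layers.fully_connected(inputs, att_size, activation_fn=tf.nn.tanh)
    #h: (60, 15, 30)
    # Score = dot(h_t, u_context), then softmax over the time axis (axis 1)
    alpha = tf.nn.softmax(tf.reduce_sum(tf.multiply(h, u_context), axis=2, keepdims=True), axis=1)
    #alpha: (60, 15, 1)
    # Weighted sum of the input vectors over time
    atten_output = tf.reduce_sum(tf.multiply(inputs, alpha), axis=1)
    #atten_output: (60, 200)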
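    The context-vector snippet above is additive attention: project each time step with a tanh layer, score it against a learned vector, softmax the scores over time, and take the weighted sum. The minimal NumPy sketch below assumes NumPy and uses illustrative names (`context_attention`, `flattened_attention`, `W`, `b`) that are not from the original; `W` and `b` stand in for the weights `fully_connected` would create. It checks that the 3-D formulation and a flattened 2-D matmul plus reshape formulation (with softmax over time steps) give the same output when they share parameters.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


def context_attention(inputs, W, b, u_context):
    """Attention with a learned context vector (hypothetical helper).

    inputs: (batch, time, features); W: (features, att); b: (att,); u_context: (att,)
    Returns alpha (batch, time, 1) and the pooled output (batch, features).
    """
    h = np.tanh(inputs @ W + b)                          # (batch, time, att)
    logits = (h * u_context).sum(axis=2, keepdims=True)  # (batch, time, 1)
    alpha = softmax(logits, axis=1)                      # normalize over time
    return alpha, (inputs * alpha).sum(axis=1)           # (batch, features)


def flattened_attention(inputs, W, b, v):
    """Same scoring via a 2-D matmul and reshape (hypothetical helper)."""
    batch, time, features = inputs.shape
    scores = (np.tanh(inputs.reshape(-1, features) @ W + b) @ v).reshape(batch, time)
    weights = softmax(scores, axis=1)
    return (inputs * weights[:, :, None]).sum(axis=1)


rng = np.random.default_rng(1)
inputs = rng.normal(size=(60, 15, 200))
W = rng.normal(scale=0.1, size=(200, 30))
b = np.zeros(30)
u = rng.normal(size=30)

alpha, out = context_attention(inputs, W, b, u)
print(alpha.shape, out.shape)                                  # (60, 15, 1) (60, 200)
print(np.allclose(out, flattened_attention(inputs, W, b, u)))  # True
```

    Keeping the time axis explicit, as the second snippet does, avoids the reshape bookkeeping entirely: the softmax axis is simply axis 1.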

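    Suppose the (60, 15, 200) input is a batch of 60 sentences, each with 15 words, where every word is a 200-dimensional vector.

    Then alpha, with shape (60, 15, 1), holds one weight per word: within each sentence the 15 weights are non-negative and sum to 1, because the softmax runs over axis 1.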
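    To make this interpretation concrete, here is a tiny worked example, assuming NumPy. The word vectors and scores are made-up values, and it uses one sentence of 3 words with 2-dimensional vectors instead of 15 words with 200 dimensions.

```python
import numpy as np

# Toy sentence: 3 words, each a 2-dimensional vector (illustrative values)
words = np.array([[1.0, 0.0],
                  [0.0, 1.0],
                  [2.0, 2.0]])
# Hypothetical attention scores; the third word scores ln 2 higher than the others
scores = np.array([0.0, 0.0, np.log(2.0)])

alpha = np.exp(scores) / np.exp(scores).sum()        # softmax over the words
sentence_vec = (words * alpha[:, None]).sum(axis=0)  # weighted sum of word vectors

print(alpha)
print(sentence_vec)
```

    The third word receives half of the attention mass (weights of about 0.25, 0.25 and 0.5), so the sentence vector, about [1.25, 1.25], is pulled toward that word's vector [2, 2].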

  • Original article: https://www.cnblogs.com/chenyaling/p/9453831.html