zoukankan      html  css  js  c++  java
  • tf.nn.nce_loss

    def nce_loss(weights,biases,inputs,labels,num_sampled,num_classes,num_true=1,sampled_values=None,remove_accidental_hits=False,partition_strategy="mod",name="nce_loss")

    假设nce_loss之前的输入数据是K维的,一共有N个类

    weight.shape = (N, K)

    bias.shape = (N)

    inputs.shape = (batch_size, K)

    labels.shape = (batch_size, num_true)

    num_true :实际的正样本个数

    num_sampled: 采样出多少个负样本

    num_classes = N

    sampled_values: 采样出的负样本,如果是None,就会用不同的sampler去采样。

    remove_accidental_hits: 如果采样时不小心采样到的负样本刚好是正样本,要不要干掉。

    partition_strategy:对weights进行embedding_lookup时并行查表时的策略。TF的embeding_lookup是在CPU里实现的,这里需要考虑多线程查表时的锁的问题。

    nce_loss的实现逻辑如下: 

    _compute_sampled_logits: 通过这个函数计算出正样本和采样出的负样本对应的output和label

    sigmoid_cross_entropy_with_logits: 通过 sigmoid cross entropy来计算output和label的loss,从而进行反向传播。

    这个函数把最后的问题转化为了num_sampled+num_real个两类分类问题,然后每个分类问题用了交叉熵的损伤函数,也就是logistic regression常用的损失函数。

    TF里还提供了一个softmax_cross_entropy_with_logits的函数,和这个有所区别。

  • 相关阅读:
    C#网络安全的一个不错的站点
    SP2已经发布,明天MS要发布一个Exchange的package
    Python学习足迹(3)
    用例子来彻底搞明白Virtual 和 非 virtual(C#)
    概述Web编程的安全极其防护措施(主要针对PHP,PERL)[]
    Java序列化
    Mybatis缓存及原理
    代理模式
    Spring的依赖注入
    Mybatis运行流程
  • 原文地址:https://www.cnblogs.com/wzdLY/p/10066620.html
Copyright © 2011-2022 走看看