zoukankan      html  css  js  c++  java
  • Tensorflow遇到的问题

    问题1、自定义loss function,y_true shape多一个维度

    def nce_loss(y_true, y_pred):
      y_true = tf.reshape(y_true, [-1])
      y_true = tf.linalg.diag(y_true)
      ret = tf.keras.metrics.categorical_crossentropy(y_true, y_pred, from_logits=False)
      ret = tf.reduce_mean(ret)
      return ret
    

    问题分析:

    如上面代码所示,tf.keras相关API,在自定义loss function时,执行model.fit方法时报错,大致意思:计算tf.keras.metrics.categorical_crossentropy时输入数据的shape不符合预期,y_true的shape是[None, 1, 1];经过手动debug,发现y_true的shape变成了[None,1],正常应该是[None,]。查了一些资料发现,出现这个问题的原因是,y_true的shape默认是与y_pred一致的,并且无法单独指定,导致在经过tf.linalg.diag函数时shape变成了[None, 1, 1]

    解决方案

    添加代码y_true = tf.reshape(y_true, [-1]),强致将shape 转成一维。

  • 相关阅读:
    C++---const
    qt--textEdit多行文本编辑框
    qt--QByteArray字节数组
    qt5--拖放
    qt5--自定义事件与事件的发送
    qt5--键盘事件
    qt5--鼠标事件
    qt5-事件过滤器
    qt5-event事件的传递
    qt-事件的接受和忽略
  • 原文地址:https://www.cnblogs.com/hwyang/p/15543170.html
Copyright © 2011-2022 走看看