zoukankan      html  css  js  c++  java
  • tensorflow2.0——meter简单的loss和acc处理方式

    1.  创建meter

      

     2.  添加数据

        

    3.  展示结果

      

     4.  清除meter  

      

    以下代码是在前面随笔中代码的基础上添加的meter相关操作:

    • import tensorflow as tf
      import datetime
      
      def preporocess(x,y):
          x = tf.cast(x,dtype=tf.float32) / 255
          x = tf.reshape(x,(-1,28 *28))                   #   铺平
          x = tf.squeeze(x,axis=0)
          # print('里面x.shape:',x.shape)
          y = tf.cast(y,dtype=tf.int32)
          return x,y
      
      def main():
          #   加载手写数字数据
          mnist = tf.keras.datasets.mnist
          (train_x, train_y), (test_x, test_y) = mnist.load_data()
          #   处理数据
              #   训练数据
          db = tf.data.Dataset.from_tensor_slices((train_x,train_y))    #   将x,y分成一一对应的元组
          db = db.map(preporocess)                                    #   执行预处理函数
          db = db.shuffle(60000).batch(2000)                          #   打乱加分组
              #   测试数据
          db_test = tf.data.Dataset.from_tensor_slices((test_x,test_y))
          db_test = db_test.map(preporocess)
          db_test = db_test.shuffle(10000).batch(10000)
          #   设置超参
          iter_num = 2000                                             #   迭代次数
          lr = 0.01                                                   #   学习率
          #   定义模型器和优化器
          model = tf.keras.Sequential([
              tf.keras.layers.Dense(256,activation='relu'),
              tf.keras.layers.Dense(128, activation='relu'),
              tf.keras.layers.Dense(64, activation='relu'),
              tf.keras.layers.Dense(32, activation='relu'),
              tf.keras.layers.Dense(10)
          ])
          # model.build(input_shape=[None,28*28])                     #   事先查看网络结构
          # model.summary()
          #   优化器
          # optimizer = tf.keras.optimizers.SGD(learning_rate=lr)
          optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
      
          #   创建meter存储loss和acc
          acc_meter = tf.keras.metrics.Accuracy()
          loss_meter = tf.keras.metrics.Mean()
      
          #   迭代训练
          for i in range(iter_num):
              for step,(x,y) in enumerate(db):
                  with tf.GradientTape() as tape:
                      logits = model(x)
                      y_onehot = tf.one_hot(y,depth=10)
                      # loss = tf.reduce_mean(tf.losses.MSE(y_onehot,logits))                                         #   差平方损失
                      loss = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot,logits,from_logits=True))     #   交叉熵损失
      
                      loss_meter.update_state(loss)                                                                   #   添加loss进meter
      
                  grads = tape.gradient(loss,model.trainable_variables)                                               #   梯度
                  grads,_ = tf.clip_by_global_norm(grads,15)                                                          #   梯度限幅
                  optimizer.apply_gradients(zip(grads,model.trainable_variables))                                     #   更新参数
                  #   tensorboard显示时写入文件的代码
                  # if step % 10 == 0:
                  #     #   将数据写入log文件
                  #     with summary_writer.as_default():
                  #         tf.summary.scalar('loss', float(loss), step=step)
                  #     pass
      
              #   计算测试集准确率
              for (x,y) in db_test:
                  logits = model(x)
                  out = tf.nn.softmax(logits,axis=1)
                  pre = tf.argmax(out,axis=1)
                  pre = tf.cast(pre,dtype=tf.int32)
                  #   调用meter接口求acc
                  acc_meter.update_state(y,pre)
                  print()
                  #   以下是自己编写的求acc的方法
                  # acc  = tf.equal(pre,y)
                  # acc = tf.cast(acc,dtype=tf.int32)
                  # acc = tf.reduce_mean(tf.cast(acc,dtype=tf.float32))
                  # print('i:{}'.format(i))
                  # print('acc:{}'.format(acc))
                  #   ************************** 将数据写入log文件 ***********************************
                  # with summary_writer.as_default():
                  #     tf.summary.scalar('acc', float(acc), step=i)
              print('loss_meter.result().numpy():', loss_meter.result().numpy())
              print('acc_meter.result().numpy():', acc_meter.result().numpy())
              loss_meter.reset_states()
              acc_meter.reset_states()
              print('第{}次迭代结束'.format(i))
      if __name__ == '__main__':
          #   ***************************** tensorboard文件处理 *******************************
          # current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')      # 当前时间
          # log_dir = 'tb_data/logs/' + current_time                              # 以当前时间作为log文件名
          # summary_writer = tf.summary.create_file_writer(log_dir)               # 创建log文件
          main()
  • 相关阅读:
    NLPIR的语义分析系统
    [译] 12步轻松搞定python装饰器
    python实现爬取千万淘宝商品的方法_python_脚本之家
    Deep Learning(深度学习)学习笔记整理系列 | @Get社区
    那些年,曾经被我们误读的大数据
    值得关注的10个python语言博客
    淘宝的评论归纳是如何做到的?
    pycharm激活码
    Windows下配置Qt 5.8+opencv 3.1.0开发环境
    Ubuntu安装opencv3.1.0后配置环境变量
  • 原文地址:https://www.cnblogs.com/cxhzy/p/13661550.html
Copyright © 2011-2022 走看看