zoukankan      html  css  js  c++  java
  • tensorflow 测量工具,与自定义训练

     

    # 新建测量器
    m = tf.keras.metrics.Accuracy()
    # 写入测量器
    m.update_state([0,1,1],[0,1,2])
    # 读取统计信息
    m.result() # 准确率为0.66
    # 清除
    m.reset_states()
    acc_meter = tf.keras.metrics.Accuracy()
    loss_meter = tf.keras.metrics.Mean() # 求平均loss
    op = tf.keras.optimizers.Adam(0.01)
    import datetime
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = "logs/"+current_time
    summary_writer = tf.summary.create_file_writer(logdir)
    for epoch in range(10):
       for step,(x,y) in enumerate(train_data):
           with tf.GradientTape() as tape:
               loss = tf.losses.categorical_crossentropy(y,model(x))
               loss_meter.update_state(loss) # 准确率
           grads = tape.gradient(loss,model.train_variables) # 求梯度
           op.apply_gradients(zip(grads,model.train_variables)) # 更新梯度 w = w - delta
           
           with summary_writer.as_default()
               tf.summary.scalar(name="loss",data=loss_meter.result().numpy(),step=xxxx)
           print(epoch,step,loss,loss_meter.result().numpy())   # numpy() 将tensor转化为变量
           loss_meter.reset_states()
       
       for step,(x,y) in enumerate(test_data):
           out = model(x)
           pred = tf.argmax(out,axis=-1)
           pred = tf.cast(pred,dtype=tf.int32)
           y = tf.cast(tf.argmax(y,axis=-1),dtype=tf.int32)
           acc_meter.update_state(y,pred)
       with summary_writer.as_default()
           tf.summary.scalar(name="acc",data=acc_meter.result().numpy(),step=xxxx)    
       print(epoch,acc_meter.result().numpy())
       acc_meter.reset_states()

     

  • 相关阅读:
    CSS浮动(float、clear)通俗讲解
    JAVA 类的加载
    数据库操作 delete和truncate的区别
    正则表达式 匹配相同数字
    Oracle EBS OM 取消订单
    Oracle EBS OM 取消订单行
    Oracle EBS OM 已存在的OM订单增加物料
    Oracle EBS OM 创建订单
    Oracle EBS INV 创建物料搬运单头
    Oracle EBS INV 创建物料搬运单
  • 原文地址:https://www.cnblogs.com/Dean0731/p/12831518.html
Copyright © 2011-2022 走看看