zoukankan      html  css  js  c++  java
  • tensorflow 测量工具,与自定义训练

     

    # 新建测量器
    m = tf.keras.metrics.Accuracy()
    # 写入测量器
    m.update_state([0,1,1],[0,1,2])
    # 读取统计信息
    m.result() # 准确率为0.66
    # 清除
    m.reset_states()
    acc_meter = tf.keras.metrics.Accuracy()
    loss_meter = tf.keras.metrics.Mean() # 求平均loss
    op = tf.keras.optimizers.Adam(0.01)
    import datetime
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = "logs/"+current_time
    summary_writer = tf.summary.create_file_writer(logdir)
    for epoch in range(10):
       for step,(x,y) in enumerate(train_data):
           with tf.GradientTape() as tape:
               loss = tf.losses.categorical_crossentropy(y,model(x))
               loss_meter.update_state(loss) # 准确率
           grads = tape.gradient(loss,model.train_variables) # 求梯度
           op.apply_gradients(zip(grads,model.train_variables)) # 更新梯度 w = w - delta
           
           with summary_writer.as_default()
               tf.summary.scalar(name="loss",data=loss_meter.result().numpy(),step=xxxx)
           print(epoch,step,loss,loss_meter.result().numpy())   # numpy() 将tensor转化为变量
           loss_meter.reset_states()
       
       for step,(x,y) in enumerate(test_data):
           out = model(x)
           pred = tf.argmax(out,axis=-1)
           pred = tf.cast(pred,dtype=tf.int32)
           y = tf.cast(tf.argmax(y,axis=-1),dtype=tf.int32)
           acc_meter.update_state(y,pred)
       with summary_writer.as_default()
           tf.summary.scalar(name="acc",data=acc_meter.result().numpy(),step=xxxx)    
       print(epoch,acc_meter.result().numpy())
       acc_meter.reset_states()

     

  • 相关阅读:
    伯克利推出世界最快的KVS数据库Anna:秒杀Redis和Cassandra
    不要什么都学-打造自己的差异化价值
    gitlab markdown支持页面内跳转
    技术人员怎样提升对业务的理解
    为什么HDFS的副本数通常选择3?
    MySQL++简单使用记录.md
    log4cpp安装使用
    log4cxx安装使用
    epoll使用总结
    工作方法-scrum+番茄工作法
  • 原文地址:https://www.cnblogs.com/Dean0731/p/12831518.html
Copyright © 2011-2022 走看看