zoukankan      html  css  js  c++  java
  • tensorflow2.0——tensorboard与预测代码相结合

    实时的显示相关数据的图

    import tensorflow as tf
    import datetime
    
    def preporocess(x,y):
        x = tf.cast(x,dtype=tf.float32) / 255
        x = tf.reshape(x,(-1,28 *28))                   #   铺平
        x = tf.squeeze(x,axis=0)
        # print('里面x.shape:',x.shape)
        y = tf.cast(y,dtype=tf.int32)
        return x,y
    
    def main():
        #   加载手写数字数据
        mnist = tf.keras.datasets.mnist
        (train_x, train_y), (test_x, test_y) = mnist.load_data()
        #   处理数据
            #   训练数据
        db = tf.data.Dataset.from_tensor_slices((train_x,train_y))    #   将x,y分成一一对应的元组
        db = db.map(preporocess)                                    #   执行预处理函数
        db = db.shuffle(60000).batch(2000)                          #   打乱加分组
            #   测试数据
        db_test = tf.data.Dataset.from_tensor_slices((test_x,test_y))
        db_test = db_test.map(preporocess)
        db_test = db_test.shuffle(10000).batch(10000)
        #   设置超参
        iter_num = 2000                                             #   迭代次数
        lr = 0.01                                                   #   学习率
        #   定义模型器和优化器
        model = tf.keras.Sequential([
            tf.keras.layers.Dense(256,activation='relu'),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(32, activation='relu'),
            tf.keras.layers.Dense(10)
        ])
        # model.build(input_shape=[None,28*28])                     #   事先查看网络结构
        # model.summary()
        #   优化器
        # optimizer = tf.keras.optimizers.SGD(learning_rate=lr)
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
    
        #   迭代训练
        db_iter = iter(db)
        for i in range(iter_num):
            for step,(x,y) in enumerate(db):
                with tf.GradientTape() as tape:
                    logits = model(x)
                    y_onehot = tf.one_hot(y,depth=10)
                    # loss = tf.reduce_mean(tf.losses.MSE(y_onehot,logits))                                         #   差平方损失
                    loss = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot,logits,from_logits=True))     #   交叉熵损失
                grads = tape.gradient(loss,model.trainable_variables)                                               #   梯度
                grads,_ = tf.clip_by_global_norm(grads,15)                                                          #   梯度限幅
                optimizer.apply_gradients(zip(grads,model.trainable_variables))                                     #   更新参数
                if step % 10 == 0:
                    #   将数据写入log文件
                    with summary_writer.as_default():
                        tf.summary.scalar('loss', float(loss), step=step)
                    pass
                    # print('i:{} , step:{} , loss:{} '.format(i,step,loss))
            #   计算测试集准确率
            for (x,y) in db_test:
                logits = model(x)
                out = tf.nn.softmax(logits,axis=1)
                pre = tf.argmax(out,axis=1)
                pre = tf.cast(pre,dtype=tf.int32)
                print(pre.shape,y.shape)
                acc  = tf.equal(pre,y)
                acc = tf.cast(acc,dtype=tf.int32)
                acc = tf.reduce_mean(tf.cast(acc,dtype=tf.float32))
                print('i:{}'.format(i))
                print('acc:{}'.format(acc))
                #   ************************** 将数据写入log文件 ***********************************
                with summary_writer.as_default():
                    tf.summary.scalar('acc', float(acc), step=i)
    
    if __name__ == '__main__':
        #   ***************************** tensorboard文件处理 *******************************
        current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')  # 当前时间
        # print('当前时间:',current_time)
        log_dir = 'tb_data/logs/' + current_time  # 以当前时间作为log文件名
        summary_writer = tf.summary.create_file_writer(log_dir)  # 创建log文件
        main()

  • 相关阅读:
    MYSQL查询练习 1
    Mysql语句练习记录
    博客园背景样式修改
    MYSQL安装与卸载(一)
    IDEA 使用与总结
    解决layui弹窗提示刷新页面一闪而逝的问题
    System.Xml.XmlException: 分析 EntityName 时出错
    PS快速把倾斜的图片调正
    iis添加asp.net网站,访问提示:由于扩展配置问题而无法提供您请求的页面。如果该页面是脚本,请添加处理程序。如果应下载文件,请添加 MIME 映射
    c# 递归查找父类的子类
  • 原文地址:https://www.cnblogs.com/cxhzy/p/13536636.html
Copyright © 2011-2022 走看看