zoukankan      html  css  js  c++  java
  • tensorflow2 基础操作

    数据的载体:

    Python中 list 

    Numpy中 np.array 

    Tensorflow中 tf.Tensor 

    tensor  的类型:

    int  float double 

    bool

    string 

    import tensorflow as tf
    
    
    tf.constant(1) # 普通的一个 int tensor 一个常量
    tf.constant(1.) # 普通的一个 float tensor 一个常量
    tf.constant(1.,dtype=tf.double) # 普通的一个 float tensor 一个常量
    
    
    tf.constant([True,False]) # bool
    
    tf.constant('hello world') # string
    View Code

    创建tensor :

    tensor 的索引和切片:

    维度变换:

    BroadCast:

    数学运算:

    +-*/ 

    **,pow ,square 

    sqrt 

    // ,%

    exp ,log

    @(矩阵乘),matmul

    linear,layer 

    前向传播:

    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import datasets
    import os
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    
    (x,y),_ = datasets.mnist.load_data()
    
    x = tf.convert_to_tensor(x,dtype=tf.float32)/255.
    y = tf.convert_to_tensor(y,dtype=tf.int32)
    
    # print(x.shape,y.shape)
    # print(tf.reduce_min(x),tf.reduce_max(x))
    # print(tf.reduce_min(y),tf.reduce_max(y))
    
    # 构建数据集 一次取128
    train_db = tf.data.Dataset.from_tensor_slices((x,y)).batch(128)
    train_iter = iter(train_db)
    sample = next(train_iter)
    print("batch:",sample[0].shape,sample[1].shape)   # batch: (128, 28, 28) (128,)
    
    # 初始化 w 和 b
    w1 = tf.random.truncated_normal([784,256],stddev=0.1)
    w1 = tf.Variable(w1) # 包装下 w1 为了能自动求 梯度
    b1 = tf.zeros([256])
    b1 = tf.Variable(b1)
    w2 = tf.random.truncated_normal([256,128],stddev=0.1)
    w2 = tf.Variable(w2)
    b2 = tf.zeros([128])
    b2 = tf.Variable(b2)
    w3 = tf.random.truncated_normal([128,10],stddev=0.1)
    w3 = tf.Variable(w3)
    b3 = tf.zeros([10])
    b3 = tf.Variable(b3)
    
    
    lr = 1e-3 # 0.001
    for i in range(10): # 对整个数据集 迭代10次
        for step,(x,y) in enumerate(train_db):
            # x [128,28,28 ],y [128,]
    
            # x [128,28*28 ],y [128,]
            # h1 = x@w1 + b1
            # 需要将 x 进行维度变换
            x = tf.reshape(x,[-1,28*28])
    
            with tf.GradientTape() as tape:
    
                h1 = x@w1 + b1  # [b,784] -> [784,256]
                h1 = tf.nn.relu(h1)  # 进行非线性的转换
                h2 = h1@w2 + b2 # [784,256] -> [256,128]
                h2 = tf.nn.relu(h2)  # 进行非线性的转换
                out = h2@w3 + b3 # [256,128] -> [128,10]
    
    
                # 然后计算误差
                # out : [b,10]
                # y:[b]
                # 首先将 y 进行 one-hot
                y_onehot = tf.one_hot(y,depth=10)
    
                # 计算均方差
                # mse = mean(sum(y- out)^2)
                loss = tf.reduce_mean(tf.square(y_onehot-out))
    
            # 求解梯度  对 w1 b1 w2 b2 w3 b3 自动求解梯度
            grads = tape.gradient(loss,[w1,b1,w2,b2,w3,b3])
    
            # 更新 w,b
            # w1 = w1 - lr*w1_grad
            # w1 = w1 - lr * grads[0]
            # b1 = b1 - lr * grads[1]
            # w2 = w2 - lr * grads[2]
            # b2 = b2 - lr * grads[3]
            # w3 = w3 - lr * grads[4]
            # b3 = b3 - lr * grads[5]
    
            w1.assign_sub(lr*grads[0])
            b1.assign_sub(lr*grads[1])
            w2.assign_sub(lr*grads[2])
            b2.assign_sub(lr*grads[3])
            w3.assign_sub(lr*grads[4])
            b3.assign_sub(lr*grads[5])
    
            if step % 100 == 0:
                print(step,'loss:',float(loss))
    前向传播

    张量的合并与分割:

    tf.concat 

    tf.stack

    tf.unstack 

    tf.split :split 比 unstack打散更灵活,

    数据统计:

    tf.norm   范数

    tf.reduce_min /max 

    tf.argmax/argmin ,返回最值的索引

    tf.equal

    tf.unique 

    张量排序:

    Sort / argsort 

    Topk

    Top-5 Acc.

    填充与复制:

    pad

    tile 

    broacast_to 

    张量限幅:

    clip_by_value 

    relu 

    clip_by_norm

    gradient clipping

    高阶操作:

    where 

    scatter_nd

    meshgrid

    神经网络与全连接层:

    数据加载:

    keras.datasets

    tf.data.Dataset.from_tensor_slices

      shuffle

      map

      batch

      repeat

  • 相关阅读:
    Java对象序列化文件追加对象的问题,以及Java的读取多个对象的问题解决方法。
    解决chrome在docky上的图标模糊或不能锁定的问题
    获取表单中的输入内容、单选按钮、复选框的输入内容
    用idea写servlet文件
    get方法和post方法
    解决Only a type can be imported. com.mysql.jdbc.Connection resolves to a package的报错问题
    idea中如何配置tomcat
    JDBC中的PreparedStatement
    JDBC中的ResultSet
    JDBCl链接中Statement
  • 原文地址:https://www.cnblogs.com/zach0812/p/13142138.html
Copyright © 2011-2022 走看看