zoukankan      html  css  js  c++  java
  • tensorflow入门:Softmax Classication

    Softmax

    Softmax用于多元分类,同logistic regression一样使用cross entropy作为损失函数,其原理不再赘述。

    另外,多元分类中我们使用one-hot编码来表示种类。
    例:A,B,C三种类别的物体表示为[1, 0, 0][0, 1, 0][0, 0, 1],这种表示方式是为了矩阵计算上的便利。

    tensorflow实现

    import tensorflow as tf
    import numpy as np
    
    def convert_to_one_hot(Y, C):
        Y = np.eye(C)[Y.reshape(-1)]
        return Y
    
    # traing data
    x_data = [[1, 2, 1, 1], [2, 1, 3, 2], [3, 1, 3, 4], [4, 1, 5, 5], [1, 7, 5, 5], 
              [1, 2, 5, 6], [1, 6, 6, 6], [1, 7, 7, 7]]
    y_data = [[0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 1, 0], [0, 1, 0], [0, 1, 0], [1, 0, 0], [1, 0, 0]]
    
    X = tf.placeholder("float", [None, 4])
    Y = tf.placeholder("float", [None, 3])
    
    # number of classes
    n_class = 3
    
    # define hyperparameter
    W = tf.Variable(tf.random_normal([4, n_class]), name="weight")
    b = tf.Variable(tf.random_normal([n_class]), name="bias")
    
    # define hypothesis using the built_in softmax
    # softmax = tf.exp(logits) / tf.reduce_mean(tf.exp(logits), dim)
    hypothesis = tf.nn.softmax(tf.matmul(X, W) + b)
    
    # cross entropy loss
    cost = tf.reduce_mean(-tf.reduce_sum(Y * tf.log(hypothesis), axis=1))
    
    # specify optimizer method
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(cost)
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        
        for step in range(2001):
            _, cost_val = sess.run([optimizer, cost], feed_dict={X: x_data, Y: y_data})
            if step % 200 == 0:
                print(step, cost_val)
        
        # test by making some prediction
        a = sess.run(hypothesis, feed_dict={X: [[1, 11, 7, 9], [1, 3, 4, 3], [1, 1, 0, 1]]})
        p = sess.run(tf.argmax(a, 1))
    
        p_one_hot = convert_to_one_hot(p, n_class)
        
        print()
        print(a)
        print()
        print(p)
        print()
        print(p_one_hot)
    
    0 5.5841656
    200 0.4481054
    400 0.36178762
    600 0.28431925
    800 0.23740968
    1000 0.21427636
    1200 0.19528978
    1400 0.17937106
    1600 0.16581444
    1800 0.15412535
    2000 0.1439426
    
    [[2.0620492e-03 9.9792922e-01 8.6969685e-06]
     [9.0453833e-01 8.4767073e-02 1.0694573e-02]
     [5.9199765e-09 2.7693674e-04 9.9972302e-01]]
    
    [1 0 2]
    
    [[0. 1. 0.]
     [1. 0. 0.]
     [0. 0. 1.]]
    

    因为是个简单的例子,损失函数不断下降。

  • 相关阅读:
    linux上实现jmeter分布式压力测试(转)
    The more,the better。
    DP_括号匹配序列问题
    MySQL基础语句
    大端模式和小端模式
    C++:bitset用法
    TCP三次握手和四次握手
    静态库与动态库
    DP_最长公共子序列/动规入门
    二维数组和指针
  • 原文地址:https://www.cnblogs.com/wanghongze95/p/13842486.html
Copyright © 2011-2022 走看看