  • TensorFlow Notes 2: Linear Regression Prediction

    Code:

    # Written for TensorFlow 1.x (graph mode: tf.placeholder, tf.Session)
    import tensorflow as tf
    import numpy as np
    import xlrd
    import matplotlib.pyplot as plt
    
    DATA_FILE = 'fire_theft.xls'
    
    # 1. Read the data file, skipping the header row
    book=xlrd.open_workbook(DATA_FILE,encoding_override="utf-8")
    sheet=book.sheet_by_index(0)
    data=np.asarray([sheet.row_values(i) for i in range(1,sheet.nrows)])
    n_samples=sheet.nrows-1
    
    # 2. Create placeholders for input X (number of fires) and label Y (number of thefts)
    X=tf.placeholder(tf.float32,name='X')
    Y=tf.placeholder(tf.float32,name='Y')
    
    # 3. Create weight and bias, initialized to 0
    w=tf.Variable(0.0,name='weights')
    b=tf.Variable(0.0,name='bias')
    
    # 4. Build the model that predicts Y
    Y_predicted = X* w +b
    
    # 5. Use squared error as the loss function
    loss=tf.square(Y-Y_predicted,name='loss')
    
    # 6. Use gradient descent with learning rate 0.001 to minimize the loss
    optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(loss)
    
    with tf.Session() as sess:
        # 7.init necessary variables (w and b)
        sess.run(tf.global_variables_initializer())
        
        writer=tf.summary.FileWriter('./my_graph/linear_reg',sess.graph)
        
        # 8. Train the model for 100 epochs (one pass over all samples per epoch)
        for i in range(100):
            total_loss =0
            for x,y in data:
                # Run the training op and fetch the loss value for this sample
                _,l=sess.run([optimizer,loss],feed_dict={X:x,Y:y})
                total_loss +=l
            print('Epoch {0}:{1}'.format(i,total_loss/n_samples))
        
        # close the writer
        writer.close()
                
        # 9. Fetch the trained values of w and b
        w_value,b_value=sess.run([w,b])
        
    # Plot the result (separate names so the X and Y placeholders are not overwritten)
    x_data,y_data=data.T[0],data.T[1]
    plt.plot(x_data,y_data,'bo',label='Real data')
    plt.plot(x_data,x_data*w_value+b_value,'r',label='Predicted data')
    plt.legend()
    plt.show()
    Data file: fire_theft.xls
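
To make steps 4 to 8 more concrete, here is a minimal pure-Python sketch of what per-sample gradient descent on a squared-error loss does. It is an illustration under stated assumptions, not TensorFlow's actual implementation: the function name `sgd_linear_regression` is hypothetical, and the data is synthetic rather than read from fire_theft.xls.

```python
# Per-sample gradient descent for y ~ w*x + b with squared-error loss.
# Assumption: synthetic data, not the contents of fire_theft.xls.

def sgd_linear_regression(samples, learning_rate=0.001, epochs=100):
    """Fit w and b, updating after every (x, y) pair like the training loop."""
    w, b = 0.0, 0.0   # same zero initialization as the TensorFlow variables
    history = []      # mean loss of each epoch
    for _ in range(epochs):
        total_loss = 0.0
        for x, y in samples:
            error = y - (w * x + b)     # Y - Y_predicted
            total_loss += error ** 2    # loss recorded before the update
            # d(error^2)/dw = -2*error*x and d(error^2)/db = -2*error,
            # so stepping against the gradient adds these terms.
            w += learning_rate * 2.0 * error * x
            b += learning_rate * 2.0 * error
        history.append(total_loss / len(samples))
    return w, b, history

# Synthetic points lying exactly on y = 2x + 1
samples = [(x, 2.0 * x + 1.0) for x in (0.0, 1.0, 2.0, 3.0, 4.0)]
w, b, history = sgd_linear_regression(samples, learning_rate=0.01, epochs=2000)
print('w = {:.3f}, b = {:.3f}'.format(w, b))
```

Because the synthetic points are noise-free, the updates settle on the generating line. On real, noisy data such as the fire/theft set, a fixed learning rate leaves the parameters fluctuating around the best fit, which is one reason the per-epoch mean loss is worth printing.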

                                                             

    Figure:

    TensorBoard: tensorboard --logdir="./my_graph/linear_reg" --port 6006
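
As a sanity check on the `w_value` and `b_value` fetched in step 9, a one-variable least-squares line also has a closed-form solution. The sketch below is an assumption-laden illustration: `least_squares_line` is a hypothetical helper and the four points are made up. With only 100 epochs at a small learning rate, the trained values need not match the closed-form ones exactly.

```python
# Closed-form least-squares fit of y ~ w*x + b for a single input variable.
# Assumption: illustrative data, not the contents of fire_theft.xls.

def least_squares_line(xs, ys):
    """Return (w, b) minimizing the sum of squared errors."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Slope = covariance(x, y) / variance(x); the 1/n factors cancel.
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    sxx = sum((x - mean_x) ** 2 for x in xs)
    w = sxy / sxx
    b = mean_y - w * mean_x   # the fitted line passes through the means
    return w, b

xs = [1.0, 2.0, 3.0, 4.0]
ys = [3.0, 5.0, 7.0, 9.0]     # exactly y = 2x + 1
w_ref, b_ref = least_squares_line(xs, ys)
print(w_ref, b_ref)
```

To compare against the trained model, one could pass `data.T[0]` and `data.T[1]` from the script above as `xs` and `ys`.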

    
    
    
  • Original post: https://www.cnblogs.com/dzzy/p/9873721.html