zoukankan      html  css  js  c++  java
  • 机器学习——TensorFlow训练Y=2*X

    import numpy as np
    import tensorflow as tf
    import matplotlib.pyplot as plt
    
    
    def moving_average(a,w=10):
        if(len(a)<w):
            return a[:]
        return [val if idx <w else sum(a[idx-w:idx])/w for idx,val in enumerate(a)]
    X_train=np.linspace(-1,1,100)
    Y_train=2*X_train+np.random.randn(*X_train.shape)*0.3
    # 训练数据
    
    X=tf.placeholder('float')
    Y=tf.placeholder('float')
    # 占位符
    
    W=tf.Variable(tf.random_normal([1]),name='weight')
    b=tf.Variable(tf.zeros([1]),name='bias')
    # 定义权重和偏置
    
    z=tf.multiply(X,W)+b
    # 定义前向结构
    
    # 反向模型的搭建即反向优化
    loss=tf.reduce_mean(tf.square(Y-z))
    
    # 定义学习率:代表调整参数的速度,这个值一般小于一,这个值越大,表明调整幅度的速度越大,但不精确
    # 这个值越小,调整幅度越小,但是速度慢
    learn_rate=0.01
    
    # 定义优化器:GridientDescentOptimizer梯度下降算法
    optimizer=tf.train.GradientDescentOptimizer(learn_rate).minimize(loss)
    
    # 迭代训练模型,初始化所有变量
    init=tf.global_variables_initializer()
    
    # 定义训练次数
    training_epochs=20
    
    # 定义显示信息
    display_step=2
    
    with tf.Session() as sess:
        sess.run(init)
        plot_data={'batch_size':[],'loss_value':[]}
        for epoch in range(training_epochs):
            for x,y in zip(X_train,Y_train):
                sess.run(optimizer,feed_dict={X:x,Y:y})
            if epoch % display_step == 0:
                loss_value = sess.run(loss, feed_dict={X:x, Y:y})
                print('Epoch:', epoch + 1, 'Loss=', loss_value, 'w=',sess.run(W),'b=',sess.run(b))
                if not (loss == 'NA'):
                    plot_data['batch_size'].append(epoch)
                    plot_data['loss_value'].append(loss_value)
        print('Finished!')
        # 可视化模型
        plt.plot(X_train, Y_train, 'ro', label='Origin data')
        plt.plot(X_train, sess.run(W) * X_train + sess.run(b),label='FittedLine')
        plt.legend()
        plt.show()
        plot_data['avgloss']=moving_average(plot_data['loss_value'])
        plt.figure(1)
        plt.subplot(211)
        plt.plot(plot_data['batch_size'],plot_data['avgloss'],'b',linewidth=1.5)
        plt.xlabel('Minbatch number')
        plt.ylabel('Loss')
        plt.title('Minibatch run vs Training loss')
        plt.show()
    
    
    
    
    
    

    在这里插入图片描述

    在这里插入图片描述

  • 相关阅读:
    C# .Net基础知识点解答
    依赖注入框架Autofac的简单使用
    Linq表达式、Lambda表达式你更喜欢哪个?
    C#抽象类、接口、虚函数和抽象函数
    MVC面试问题与答案
    并发 并行 同步 异步 多线程的区别
    .Net中的控制翻转和依赖注入
    解析ASP.NET WebForm和Mvc开发的区别
    测试与代码质量
    netty 同步调用
  • 原文地址:https://www.cnblogs.com/hzcya1995/p/13309446.html
Copyright © 2011-2022 走看看