zoukankan      html  css  js  c++  java
  • tensorflow拟合随机生成的三维数据【学习笔记】

    平台信息:
    PC:ubuntu18.04、i5、anaconda2、cuda9.0、cudnn7.0.5、tensorflow1.10、GTX1060

    作者:庄泽彬(欢迎转载,请注明作者)

    说明:感谢tensorflow社区,本文是在社区的学习笔记,生成随机的三维数据,之后用平面去拟合。

    相关代码:

     1 #!/usr/bin/env python2
     2 # -*- coding: utf-8 -*-
     3 """
     4 Created on Thu Oct 11 19:54:15 2018
     5 
     6 @author: zhuang
     7 """
     8 
     9 import tensorflow as tf
    10 import numpy as np
    11 
    12 #生成随机数
    13 x_data = np.float32(np.random.rand(2,100))
    14 y_data = np.dot([0.100,0.200],x_data) + 0.300
    15 
    16 
    17 # 初始化参数
    18 b = tf.Variable(tf.zeros([1]))
    19 # w为1x2的矩阵,在-1.0到1.0之间均匀分布
    20 w = tf.Variable(tf.random_uniform([1,2],-1.0,1.0))
    21 y = tf.matmul(w,x_data) + b
    22 
    23 # 使用最小化方差进行梯度下降,来不断更新参数,学习率设置为0.5
    24 loss = tf.reduce_mean(tf.square(y-y_data))
    25 optimizer = tf.train.GradientDescentOptimizer(0.5)
    26 train = optimizer.minimize(loss)
    27 
    28 #init = tf.initialize_all_variables()
    29 #新版本的tensorflow使用下面的接口,老版本使用上面的接口
    30 init = tf.global_variables_initializer()
    31 
    32 sess = tf.Session()
    33 sess.run(init)
    34 
    35 #进行拟合找到适合的参数
    36 for step in xrange(0,201):
    37     sess.run(train)
    38     if (step) % 20 == 0:
    39         print step,sess.run(w),sess.run(b)
    40         

    实验结果:

     1 runfile('/home/zhuang/project/1-AI/My_AI_Study_Project/3-tensorflow/005-test.py', wdir='/home/zhuang/project/1-AI/My_AI_Study_Project/3-tensorflow')
     2 0 [[0.39894426 0.3333286 ]] [0.14586714]
     3 20 [[0.16806586 0.26403958]] [0.22699882]
     4 40 [[0.12101775 0.22725435]] [0.27309927]
     5 60 [[0.10728491 0.21063562]] [0.28998315]
     6 80 [[0.10265137 0.20403926]] [0.2962562]
     7 100 [[0.10098286 0.20152006]] [0.29859895]
     8 120 [[0.10036676 0.20057023]] [0.29947543]
     9 140 [[0.10013718 0.20021366]] [0.29980358]
    10 160 [[0.10005137 0.20008004]] [0.29992643]
    11 180 [[0.10001925 0.20002998]] [0.29997244]
    12 200 [[0.10000721 0.20001122]] [0.29998967]

    我们的目标方程y_data = np.dot([0.100,0.200],x_data) + 0.300,经过200次的训练更新w,b参数为[[0.10000721 0.20001122]] [0.29998967],非常接近我们方程的参数。

     

  • 相关阅读:
    第十四章:(2)Spring Boot 与 分布式 之 Dubbo + Zookeeper
    第十四章:(1)Spring Boot 与 分布式 之 分布式介绍
    第九章:Redis 的Java客户端Jedis
    第十三章:(2)Spring Boot 与 安全 之 SpringBoot + SpringSecurity + Thymeleaf
    第八章:(1)Redis 的复制(Master/Slave)
    java学习
    周末总结4
    java
    Cheatsheet: 2012 12.17 ~ 12.31
    Cheatsheet: 2012 10.01 ~ 10.07
  • 原文地址:https://www.cnblogs.com/zzb-Dream-90Time/p/9774858.html
Copyright © 2011-2022 走看看