zoukankan      html  css  js  c++  java
  • Tensorflow梯度下降应用

    import tensorflow as tf
    import numpy as np

    #使用numpy生成随机点
    x_data = np.random.rand(100)
    y_data = x_data*0.1 + 0.2

    #构造一个线性模型
    b = tf.Variable(0.0)
    k = tf.Variable(0.0)
    y = k*x_data+b

    #二次代价函数
    loss = tf.reduce_mean(tf.square(y_data-y))#误差平方求平均值
    #定义一个梯度下降来进行训练的优化器
    optimizer = tf.train.GradientDescentOptimizer(0.2)

    #最小化代价函数
    train = optimizer.minimize(loss)
    #初始化变量
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
      sess.run(init)
        for index in range(201):
          sess.run(train)
          if index%10==0:
            print(index,sess.run([k,b]))

    ###########输出

    0 [0.058540713, 0.10185367]
    10 [0.10913987, 0.19464658]
    20 [0.10734161, 0.19575559]
    30 [0.10587782, 0.19660187]
    40 [0.10470589, 0.19727939]
    50 [0.10376761, 0.19782184]
    60 [0.10301641, 0.19825613]
    70 [0.10241497, 0.19860384]
    80 [0.10193346, 0.19888222]
    90 [0.10154796, 0.19910508]
    100 [0.10123933, 0.19928351]
    110 [0.10099223, 0.19942637]
    120 [0.10079438, 0.19954075]
    130 [0.10063599, 0.19963232]
    140 [0.10050918, 0.19970562]
    150 [0.10040767, 0.19976433]
    160 [0.10032637, 0.19981132]
    170 [0.1002613, 0.19984894]
    180 [0.1002092, 0.19987905]
    190 [0.1001675, 0.19990316]
    200 [0.10013408, 0.19992249]
  • 相关阅读:
    ssrf简介
    Mysql 命令 load data infile
    基于约束的SQL注入笔记
    ms17-010
    thinkphp5.0&5.1命令执行 和 thinkphp3.2.3sql注入
    抓取分析菜刀流量
    lamp环境的搭建
    php伪协议
    LeetCode-336 Palindrome Pairs
    LeetCode-335 Self Crossing
  • 原文地址:https://www.cnblogs.com/herd/p/9457418.html
Copyright © 2011-2022 走看看